├── .classpath ├── .gitignore ├── .project ├── LICENSE ├── MOA-dependencies.jar ├── README.md ├── ROSE-1.0.jar ├── pom.xml ├── sizeofag-1.0.4.jar └── src └── main └── java ├── experiments ├── Collect_Results.java ├── Datasets.java ├── Drifting_Imbalance_Ratio.java ├── Drifting_Noise_and_Imbalance_Ratio.java ├── Instance_Level_Difficulties.java └── Static_Imbalance_Ratio.java └── moa ├── classifiers ├── meta │ └── imbalanced │ │ └── ROSE.java └── trees │ └── RandomSubspaceHT.java ├── evaluation ├── WindowAUCImbalancedPerformanceEvaluator.java ├── WindowAUCMultiClassImbalancedPerformanceEvaluator.java └── WindowImbalancedClassificationPerformanceEvaluator.java └── streams ├── filters └── AddNoiseFilterFeatures.java └── generators └── imbalanced ├── AgrawalGenerator.java ├── AssetNegotiationGenerator.java ├── HyperplaneGenerator.java ├── MixedGenerator.java ├── RandomRBFGenerator.java ├── RandomRBFGeneratorDrift.java ├── RandomTreeGenerator.java ├── SEAGenerator.java ├── STAGGERGenerator.java ├── SineGenerator.java └── TextGenerator.java /.classpath: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target/ 2 | /datasets/ 3 | -------------------------------------------------------------------------------- /.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | ROSE 4 | 5 | 6 | 7 | 8 | 9 | org.eclipse.jdt.core.javabuilder 10 | 11 | 12 | 13 | 14 | org.eclipse.m2e.core.maven2Builder 15 | 16 | 17 | 18 | 19 | 20 | org.eclipse.jdt.core.javanature 21 | org.eclipse.m2e.core.maven2Nature 22 | 23 | 24 | -------------------------------------------------------------------------------- /MOA-dependencies.jar: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/canoalberto/ROSE/c86c8e111c8f823d394639a8a3a0e3d3598ad5a7/MOA-dependencies.jar -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ROSE: Robust Online Self-Adjusting Ensemble for Continual Learning on Imbalanced Drifting Data Streams 2 | 3 |

Data streams are potentially unbounded sequences of instances arriving over time to a classifier. Designing algorithms that are capable of dealing with massive, rapidly arriving information is one of the most dynamically developing areas of machine learning. Such learners must be able to deal with a phenomenon known as concept drift, where the data stream may be subject to various changes in its characteristics over time. Furthermore, distributions of classes may evolve over time, leading to a highly difficult non-stationary class imbalance. In this work we introduce Robust Online Self-Adjusting Ensemble (ROSE), a novel online ensemble classifier capable of dealing with all of the mentioned challenges. The main features of ROSE are: (i) online training of base classifiers on variable size random subsets of features; (ii) online detection of concept drift and creation of a background ensemble for faster adaptation to changes; (iii) sliding window per class to create skew-insensitive classifiers regardless of the current imbalance ratio; and (iv) self-adjusting bagging to enhance the exposure of difficult instances from minority classes. The interplay among these features leads to an improved performance in various data stream mining benchmarks. An extensive experimental study comparing with 30 ensemble classifiers shows that ROSE is a robust and well-rounded classifier for drifting imbalanced data streams, especially under the presence of noise and class imbalance drift, while maintaining competitive time complexity and memory consumption. Results are supported by a thorough non-parametric statistical analysis.

4 | 5 | ## Using ROSE 6 | 7 | Download the pre-compiled jar files or import the project source code into [MOA](https://github.com/Waikato/moa). See the src/main/java/experiments folder to reproduce our research. 8 | 9 | ### Experiment 1: Static imbalance ratio 10 | Use any algorithm in `moa.classifiers` and imbalanced generator in `moa.streams.generators.imbalanced`. The parameter `-m` controls the proportion of the minority vs majority class, e.g. `-m 0.01` reflects an imbalance ratio of 100. 11 | ``` 12 | java -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar:MOA-dependencies.jar moa.DoTask EvaluateInterleavedTestThenTrain -e "(WindowAUCImbalancedPerformanceEvaluator)" -s "(moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 1 -m 0.01)" -l "(moa.classifiers.meta.imbalanced.ROSE)" -i 1000000 -f 500 -d results.csv 13 | ``` 14 | 15 | | Generator | Instances | Features | Classes | Static Imbalance Ratios | Concept Drift | 16 | | -------- | ---: | ---: | ---: |:-: |:-: | 17 | | Agrawal | 1,000,000 | 9 | 2 | {5, 10, 20, 50, 100} | None | 18 | | AssetNegotiation | 1,000,000 | 5 | 2 | {5, 10, 20, 50, 100} | None | 19 | | RandomRBF | 1,000,000 | 10 | 2 | {5, 10, 20, 50, 100} | None | 20 | | SEA | 1,000,000 | 3 | 2 | {5, 10, 20, 50, 100} | None | 21 | | Sine | 1,000,000 | 4 | 2 | {5, 10, 20, 50, 100} | None | 22 | | Hyperplane | 1,000,000 | 10 | 2 | {5, 10, 20, 50, 100} | None | 23 | 24 | ### Experiment 2: Drifting imbalance ratio 25 | Use any algorithm in `moa.classifiers` and imbalanced generator in `moa.streams.generators.imbalanced`. The parameter `-m` controls the proportion of the minority vs majority class, e.g. `-m 0.01` reflects an imbalance ratio of 100. Generate drifting imbalance ratios by chaining `ConceptDriftStream` streams with different imbalance ratios. The parameter `-p` controls the position of the drift and `-w` the width of the drift (sudden vs gradual). 
The example shows a sequence of increasing then decreasing imbalance ratio ({5, 10, 20, 100, 20, 10, 5}). 26 | ``` 27 | java -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar:MOA-dependencies.jar moa.DoTask EvaluateInterleavedTestThenTrain -e "(WindowAUCImbalancedPerformanceEvaluator)" -s "(ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.2) -r 1 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 2 -f 2 -m 0.1) -r 2 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 3 -f 2 -m 0.05) -r 3 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 4 -f 2 -m 0.01) -r 4 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 5 -f 2 -m 0.01) -r 5 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 6 -f 2 -m 0.05) -r 6 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 7 -f 2 -m 0.1) -r 7 -d (moa.streams.generators.imbalanced.AgrawalGenerator -i 8 -f 2 -m 0.2) -r 8 -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1)" -l "(moa.classifiers.meta.imbalanced.ROSE)" -i 1000000 -f 500 -d results.csv 28 | ``` 29 | 30 | | Generator | Instances | Features | Classes | Drifting imbalance ratios | Concept Drift | 31 | | -------- | ---: | ---: | ---: |:-: |:-: | 32 | | Agrawal | 1,000,000 | 9 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | 33 | | AssetNegotiation | 1,000,000 | 5 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | 34 | | RandomRBF | 1,000,000 | 10 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | 35 | | SEA | 1,000,000 | 3 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | 36 | | Sine | 1,000,000 | 4 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | 37 | | Hyperplane | 1,000,000 | 10 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, 
gradual} | 38 | 39 | ### Experiment 3: Instance-level difficulties 40 | Use any algorithm in `moa.classifiers` and a dataset with instance-level difficulties generated using these imbalanced generators. 41 | 42 | ``` 43 | java -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar:MOA-dependencies.jar moa.DoTask EvaluateInterleavedTestThenTrain -e "(WindowAUCImbalancedPerformanceEvaluator)" -s "(ArffFileStream -f Split5+Im1+Borderline20+Rare20.arff)" -l "(moa.classifiers.meta.imbalanced.ROSE)" -f 500 -d results.csv 44 | ``` 45 | 46 | | Generator | Instances | Features | Classes | Static Imbalance Ratios | Percentage of difficult instances | 47 | | -------- | ---: | ---: | ---: |:-: |:-- | 48 | | Borderline | 200,000 | 5 | 2 | {1, 10, 100} | {20%, 40%, 60%, 80%, 100%} | 49 | | Rare | 200,000 | 5 | 2 | {1, 10, 100} | {20%, 40%, 60%, 80%, 100%} | 50 | | Borderline + Rare | 200,000 | 5 | 2 | {1, 10, 100} | {20%, 40%} | 51 | 52 | ### Experiment 4: Robustness to noise drift 53 | Use any algorithm in `moa.classifiers` and imbalanced generator in `moa.streams.generators.imbalanced`. The `-f` parameter of the `AddNoiseFilterFeatures` filter controls the fraction of features with noise, e.g. `-f 0.40` adds noise to 40% of the features. The parameter `-m` controls the proportion of the minority vs majority class, e.g. `-m 0.01` reflects an imbalance ratio of 100. Generate drifting noise and imbalance ratios by chaining `ConceptDriftStream` streams with different imbalance ratios, percentages of features with noise, and noise seed `-r`. The parameter `-p` controls the position of the drift and `-w` the width of the drift (sudden vs gradual). The example shows noise drifting to other features together with an increasing then decreasing imbalance ratio ({5, 10, 20, 100, 20, 10, 5}). 
54 | ``` 55 | java -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar:MOA-dependencies.jar moa.DoTask EvaluateInterleavedTestThenTrain -e "(WindowAUCImbalancedPerformanceEvaluator)" -s "(ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.2) -f (AddNoiseFilterFeatures -r 1 -a 0.99 -f 0.40)) -r 1 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 2 -f 2 -m 0.1) -f (AddNoiseFilterFeatures -r 2 -a 0.99 -f 0.40)) -r 2 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 3 -f 2 -m 0.05) -f (AddNoiseFilterFeatures -r 3 -a 0.99 -f 0.40)) -r 3 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 4 -f 2 -m 0.01) -f (AddNoiseFilterFeatures -r 4 -a 0.99 -f 0.40)) -r 4 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 5 -f 2 -m 0.01) -f (AddNoiseFilterFeatures -r 5 -a 0.99 -f 0.40)) -r 5 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 6 -f 2 -m 0.05) -f (AddNoiseFilterFeatures -r 6 -a 0.99 -f 0.40)) -r 6 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 7 -f 2 -m 0.1) -f (AddNoiseFilterFeatures -r 7 -a 0.99 -f 0.40)) -r 7 -d (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 8 -f 2 -m 0.2) -f (AddNoiseFilterFeatures -r 8 -a 0.99 -f 0.40)) -r 8 -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1)" -l "(moa.classifiers.meta.imbalanced.ROSE)" -i 1000000 -f 500 -d results.csv 56 | ``` 57 | 58 | | Generator | Instances | Features | Classes | Drifting imbalance ratios | Concept Drift | Percentage of features with noise 59 | | -------- | ---: | ---: | ---: |:-: |:-: | :-- | 60 | | Agrawal | 1,000,000 | 9 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, 
gradual} | {10%, 20%, 30%, 40%} 61 | | AssetNegotiation | 1,000,000 | 5 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | {10%, 20%, 30%, 40%} 62 | | RandomRBF | 1,000,000 | 10 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | {10%, 20%, 30%, 40%} 63 | | SEA | 1,000,000 | 3 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | {10%, 20%, 30%, 40%} 64 | | Sine | 1,000,000 | 4 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | {10%, 20%, 30%, 40%} 65 | | Hyperplane | 1,000,000 | 10 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | {10%, 20%, 30%, 40%} 66 | 67 | ### Experiment 5: Datasets 68 | Use any algorithm in `moa.classifiers` and dataset from UCI / KEEL dataset repositories. 69 | ``` 70 | java -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar:MOA-dependencies.jar moa.DoTask EvaluateInterleavedTestThenTrain -e "(WindowAUCMultiClassImbalancedPerformanceEvaluator)" -s "(ArffFileStream -f dataset.arff)" -l "(moa.classifiers.meta.imbalanced.ROSE)" -f 500 -d results.csv 71 | ``` 72 | 73 | | Dataset | Instances | Features | Classes | 74 | | -------- | ---: | ---: | ---: | 75 | | adult | 45,222 | 14 | 2 | 76 | | airlines | 539,383 | 7 | 2 | 77 | | bridges | 1,000,000 | 12 | 6 | 78 | | census | 299,284 | 41 | 2 | 79 | | coil2000 | 9,822 | 85 | 2 | 80 | | connect-4 | 67,557 | 42 | 3 | 81 | | covtype | 581,012 | 54 | 7 | 82 | | dj30 | 138,166 | 7 | 30 | 83 | | electricity | 45,312 | 8 | 2 | 84 | | fars | 100,968 | 29 | 8 | 85 | | gas-sensor | 13,910 | 128 | 6 | 86 | | gmsc | 150,000 | 10 | 2 | 87 | | intel-lab | 2,313,153 | 5 | 58 | 88 | | kddcup | 4,898,431 | 41 | 23 | 89 | | kr-vs-k | 28,056 | 6 | 18 | 90 | | letter | 20,000 | 16 | 26 | 91 | | magic | 19,020 | 10 | 2 | 92 | | nomao | 34,465 | 118 | 2 | 93 | | penbased | 10,992 | 16 | 10 | 94 | | poker | 829,201 | 10 | 10 | 95 | | powersupply | 29,928 | 2 | 24 | 96 | | shuttle | 57,999 | 9 | 7 | 97 | | thyroid | 7,200 | 21 | 3 | 98 | | zoo | 1,000,000 | 17 
| 7 | 99 | 100 | ## Citation 101 | ``` 102 | @article{cano2022rose, 103 | title={{ROSE: robust online self-adjusting ensemble for continual learning on imbalanced drifting data streams}}, 104 | author={Cano, Alberto and Krawczyk, Bartosz}, 105 | journal={Machine Learning}, 106 | volume={111}, 107 | number={7}, 108 | pages={2561--2599}, 109 | year={2022} 110 | } 111 | ``` 112 | -------------------------------------------------------------------------------- /ROSE-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/canoalberto/ROSE/c86c8e111c8f823d394639a8a3a0e3d3598ad5a7/ROSE-1.0.jar -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 4 | 4.0.0 5 | 6 | edu.vcu.acano 7 | ROSE 8 | 1.0 9 | 10 | ROSE: Robust Online Self-Adjusting Ensemble for Continual Learning on Imbalanced Drifting Data Streams 11 | https://github.com/canoalberto/ROSE 12 | 13 | 14 | Virginia Commonwealth University, Richmond, Virginia, USA 15 | http://www.vcu.edu/ 16 | 17 | 18 | 19 | 20 | GNU General Public License 3.0 21 | http://www.gnu.org/licenses/gpl-3.0.txt 22 | repo 23 | 24 | 25 | 26 | 27 | 28 | canoalberto 29 | Alberto Cano 30 | acano@vcu.edu 31 | 32 | 33 | 34 | 35 | UTF-8 36 | 1.11 37 | 1.11 38 | 39 | 40 | 41 | 42 | junit 43 | junit 44 | 4.13.2 45 | test 46 | 47 | 48 | nz.ac.waikato.cms.moa 49 | moa 50 | 2021.07.0 51 | 52 | 53 | nz.ac.waikato.cms.weka 54 | weka-dev 55 | 3.9.5 56 | 57 | 58 | org.apache.commons 59 | commons-math3 60 | 3.6.1 61 | 62 | 63 | gov.nist.math 64 | jama 65 | 1.0.3 66 | 67 | 68 | 69 | 70 | 71 | 72 | maven-assembly-plugin 73 | 74 | 75 | 76 | moa.DoTask 77 | 78 | 79 | 80 | jar-with-dependencies 81 | 82 | 83 | 84 | 85 | make-assembly 86 | package 87 | 88 | single 89 | 90 | 91 | 92 | 93 | 94 | org.apache.maven.plugins 95 | maven-compiler-plugin 96 | 3.8.1 97 | 98 | 11 99 | 11 100 | 
101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /sizeofag-1.0.4.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/canoalberto/ROSE/c86c8e111c8f823d394639a8a3a0e3d3598ad5a7/sizeofag-1.0.4.jar -------------------------------------------------------------------------------- /src/main/java/experiments/Collect_Results.java: -------------------------------------------------------------------------------- 1 | package experiments; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.File; 5 | import java.io.FileReader; 6 | 7 | public class Collect_Results { 8 | 9 | public static void main(String[] args) throws Exception { 10 | 11 | String resultsPath = "results_datasets"; 12 | 13 | // Replace with the collection of files (datasets / generators) to parse 14 | String[] files = new String[] { 15 | "thyroid", 16 | "coil2000", 17 | "penbased", 18 | "gas-sensor", 19 | "magic", 20 | "letter", 21 | "kr-vs-k", 22 | "powersupply", 23 | "adult", 24 | "electricity", 25 | "shuttle", 26 | "connect-4", 27 | "fars", 28 | "dj30", 29 | "census", 30 | "airlines", 31 | "covtype", 32 | "poker", 33 | "bridges", 34 | "zoo", 35 | "IntelLabSensors", 36 | "kddcup", 37 | "GMSC", 38 | "nomao", 39 | }; 40 | 41 | String[] algorithms = new String[] { 42 | "moa.classifiers.meta.imbalanced.ROSE", 43 | "moa.classifiers.meta.KUE", 44 | "moa.classifiers.meta.AccuracyWeightedEnsemble", 45 | "moa.classifiers.meta.AccuracyUpdatedEnsemble1", 46 | "moa.classifiers.meta.AccuracyUpdatedEnsemble2", 47 | "moa.classifiers.meta.DynamicWeightedMajority", 48 | "moa.classifiers.meta.SAE2", 49 | "moa.classifiers.meta.DACC", 50 | "moa.classifiers.meta.ADACC", 51 | "moa.classifiers.meta.AdaptiveRandomForest", 52 | "moa.classifiers.meta.ADOB", 53 | "moa.classifiers.meta.BOLE", 54 | "moa.classifiers.meta.GOOWE", 55 | "moa.classifiers.meta.HeterogeneousEnsembleBlast", 56 | 
"moa.classifiers.meta.LeveragingBag", 57 | "moa.classifiers.meta.OCBoost", 58 | "moa.classifiers.meta.OzaBag", 59 | "moa.classifiers.meta.OzaBagAdwin", 60 | "moa.classifiers.meta.OzaBagASHT", 61 | "moa.classifiers.meta.OzaBoost", 62 | "moa.classifiers.meta.OzaBoostAdwin", 63 | "moa.classifiers.meta.StreamingRandomPatches", 64 | "moa.classifiers.meta.UOB", 65 | "moa.classifiers.meta.OOB", 66 | "moa.classifiers.meta.imbalanced.OnlineSMOTEBagging", 67 | "moa.classifiers.meta.imbalanced.OnlineUnderOverBagging", 68 | "moa.classifiers.meta.imbalanced.CSMOTE", 69 | "moa.classifiers.meta.imbalanced.OnlineAdaBoost", 70 | "moa.classifiers.meta.imbalanced.OnlineAdaC2", 71 | "moa.classifiers.meta.imbalanced.OnlineRUSBoost", 72 | "moa.classifiers.meta.imbalanced.RebalanceStream", 73 | }; 74 | 75 | String[] algorithmsFilename = new String[algorithms.length]; 76 | 77 | for(int alg = 0; alg < algorithms.length; alg++) { 78 | algorithmsFilename[alg] = algorithms[alg].replaceAll("moa.classifiers.meta.", "").replaceAll("imbalanced.", ""); 79 | } 80 | 81 | metric("Accuracy", "averaged", resultsPath, files, algorithmsFilename); 82 | metric("Kappa", "averaged", resultsPath,files, algorithmsFilename); 83 | metric("AUC", "averaged", resultsPath, files, algorithmsFilename); // do not use for multi-class datasets 84 | metric("PMAUC", "averaged", resultsPath, files, algorithmsFilename); // use for multi-class datasets 85 | } 86 | 87 | private static void metric(String metricName, String outcome, String resultsPath, String[] files, String[] algorithms) throws Exception { 88 | 89 | System.out.print(metricName + "\t"); 90 | 91 | for(String algorithm : algorithms) 92 | System.out.print(algorithm + "\t"); 93 | System.out.println(""); 94 | 95 | for(String file : files) 96 | { 97 | System.out.print(file + "\t"); 98 | 99 | for(int alg = 0; alg < algorithms.length; alg++) 100 | { 101 | int count = 0; 102 | double sum = 0; 103 | double lastValue = 0; 104 | 105 | String filename = resultsPath + "/" + 
algorithms[alg] + "-" + file + ".csv"; 106 | 107 | if(new File(filename).exists()) 108 | { 109 | BufferedReader br = new BufferedReader(new FileReader(new File(filename))); 110 | 111 | String line; 112 | line = br.readLine(); // header line 113 | 114 | String[] columns = null; 115 | 116 | try { 117 | columns = line.split(","); 118 | } catch (Exception e) { 119 | e.printStackTrace(); 120 | System.exit(-1); 121 | } 122 | 123 | int index = -1; 124 | for(int i = 0; i < columns.length; i++) 125 | if(columns[i].equals(metricName)) 126 | index = i; 127 | 128 | while((line = br.readLine()) != null) 129 | { 130 | try { 131 | if(!line.split(",")[index].equals("?")) 132 | { 133 | lastValue = Double.parseDouble(line.split(",")[index]); 134 | sum += lastValue; 135 | count++; 136 | } 137 | } catch (Exception e) { 138 | lastValue = 0; 139 | sum += lastValue; 140 | count++; 141 | } 142 | } 143 | 144 | br.close(); 145 | } 146 | 147 | if(outcome.equalsIgnoreCase("averaged")) 148 | System.out.print((sum/count) + "\t"); 149 | else 150 | System.out.print(lastValue + "\t"); 151 | } 152 | 153 | System.out.println(""); 154 | } 155 | } 156 | } -------------------------------------------------------------------------------- /src/main/java/experiments/Datasets.java: -------------------------------------------------------------------------------- 1 | package experiments; 2 | 3 | import org.apache.commons.lang3.SystemUtils; 4 | 5 | public class Datasets { 6 | 7 | public static void main(String[] args) throws Exception { 8 | 9 | // Download datasets from https://people.vcu.edu/~acano/ROSE/datasets.zip 10 | 11 | String[] datasets = new String[] { 12 | "thyroid", 13 | "coil2000", 14 | "penbased", 15 | "gas-sensor", 16 | "magic", 17 | "letter", 18 | "kr-vs-k", 19 | "powersupply", 20 | "adult", 21 | "electricity", 22 | "shuttle", 23 | "connect-4", 24 | "fars", 25 | "dj30", 26 | "census", 27 | "airlines", 28 | "covtype", 29 | "poker", 30 | "bridges", 31 | "zoo", 32 | "IntelLabSensors", 33 | 
"kddcup", 34 | "GMSC", 35 | "nomao", 36 | }; 37 | 38 | String[] algorithms = new String[] { 39 | "moa.classifiers.meta.imbalanced.ROSE", 40 | "moa.classifiers.meta.KUE", 41 | "moa.classifiers.meta.AccuracyWeightedEnsemble", 42 | "moa.classifiers.meta.AccuracyUpdatedEnsemble1", 43 | "moa.classifiers.meta.AccuracyUpdatedEnsemble2", 44 | "moa.classifiers.meta.DynamicWeightedMajority", 45 | "moa.classifiers.meta.SAE2", 46 | "moa.classifiers.meta.DACC", 47 | "moa.classifiers.meta.ADACC", 48 | "moa.classifiers.meta.AdaptiveRandomForest", 49 | "moa.classifiers.meta.ADOB", 50 | "moa.classifiers.meta.BOLE", 51 | "moa.classifiers.meta.GOOWE", 52 | "moa.classifiers.meta.HeterogeneousEnsembleBlast", 53 | "moa.classifiers.meta.LeveragingBag", 54 | "moa.classifiers.meta.OCBoost", 55 | "moa.classifiers.meta.OzaBag", 56 | "moa.classifiers.meta.OzaBagAdwin", 57 | "moa.classifiers.meta.OzaBagASHT", 58 | "moa.classifiers.meta.OzaBoost", 59 | "moa.classifiers.meta.OzaBoostAdwin", 60 | "moa.classifiers.meta.StreamingRandomPatches", 61 | "moa.classifiers.meta.UOB", 62 | "moa.classifiers.meta.OOB", 63 | "moa.classifiers.meta.imbalanced.OnlineSMOTEBagging", 64 | "moa.classifiers.meta.imbalanced.OnlineUnderOverBagging", 65 | "moa.classifiers.meta.imbalanced.CSMOTE", 66 | "moa.classifiers.meta.imbalanced.OnlineAdaBoost", 67 | "moa.classifiers.meta.imbalanced.OnlineAdaC2", 68 | "moa.classifiers.meta.imbalanced.OnlineRUSBoost", 69 | "moa.classifiers.meta.imbalanced.RebalanceStream", 70 | }; 71 | 72 | String[] algorithmsFilename = new String[algorithms.length]; 73 | 74 | for(int alg = 0; alg < algorithms.length; alg++) { 75 | algorithmsFilename[alg] = algorithms[alg].replaceAll("moa.classifiers.meta.", "").replaceAll("imbalanced.", ""); 76 | } 77 | 78 | String classpathSeparator = SystemUtils.IS_OS_UNIX ? 
":" : ";"; 79 | 80 | for(int dat = 0; dat < datasets.length; dat++) 81 | { 82 | for(int alg = 0; alg < algorithms.length; alg++) 83 | { 84 | // Replace evaluator with WindowAUCMultiClassImbalancedPerformanceEvaluator for multi-class datasets 85 | System.out.println("java -Xms16g -Xmx1024g -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar" + classpathSeparator + "MOA-dependencies.jar " 86 | + "moa.DoTask EvaluateInterleavedTestThenTrain" 87 | + " -e \"(WindowAUCMultiClassImbalancedPerformanceEvaluator)\"" 88 | + " -s \"(ArffFileStream -f datasets/" + datasets[dat] + ".arff)\"" 89 | + " -l \"(" + algorithms[alg] + ")\"" 90 | + " -f 500" 91 | + " -d results_datasets/" + algorithmsFilename[alg] + "-" + datasets[dat] + ".csv"); 92 | } 93 | } 94 | } 95 | } -------------------------------------------------------------------------------- /src/main/java/experiments/Drifting_Imbalance_Ratio.java: -------------------------------------------------------------------------------- 1 | package experiments; 2 | 3 | import org.apache.commons.lang3.SystemUtils; 4 | 5 | public class Drifting_Imbalance_Ratio { 6 | 7 | public static void main(String[] args) throws Exception { 8 | 9 | String[] generators = new String[] { 10 | // Sudden drift 11 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.2) -r 1 " 12 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 2 -f 2 -m 0.1) -r 2 " 13 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 3 -f 2 -m 0.05) -r 3 " 14 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 4 -f 2 -m 0.01) -r 4 " 15 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 5 -f 2 -m 0.01) -r 5 " 16 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 6 -f 2 -m 0.05) -r 6 " 17 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 7 -f 2 -m 
0.1) -r 7 " 18 | + "-d (moa.streams.generators.imbalanced.AgrawalGenerator -i 8 -f 2 -m 0.2) -r 8 " 19 | + "-p 125000 -w 1) " 20 | + "-p 125000 -w 1) " 21 | + "-p 125000 -w 1) " 22 | + "-p 125000 -w 1) " 23 | + "-p 125000 -w 1) " 24 | + "-p 125000 -w 1) " 25 | + "-p 125000 -w 1", 26 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 1 -f 1 -m 0.2) -r 1 " 27 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 2 -f 1 -m 0.1) -r 2 " 28 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 3 -f 1 -m 0.05) -r 3 " 29 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 4 -f 1 -m 0.01) -r 4 " 30 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 5 -f 1 -m 0.01) -r 5 " 31 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 6 -f 1 -m 0.05) -r 6 " 32 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 7 -f 1 -m 0.1) -r 7 " 33 | + "-d (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 8 -f 1 -m 0.2) -r 8 " 34 | + "-p 125000 -w 1) " 35 | + "-p 125000 -w 1) " 36 | + "-p 125000 -w 1) " 37 | + "-p 125000 -w 1) " 38 | + "-p 125000 -w 1) " 39 | + "-p 125000 -w 1) " 40 | + "-p 125000 -w 1", 41 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 1 -a 10 -c 2 -m 0.2) -r 1 " 42 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 2 -a 10 -c 2 -m 0.1) -r 2 " 43 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 3 -a 10 -c 2 -m 0.05) -r 3 " 44 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 4 -a 10 -c 2 -m 0.01) -r 4 " 45 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 5 -a 10 -c 2 -m 0.01) -r 5 " 46 | 
+ "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 6 -a 10 -c 2 -m 0.05) -r 6 " 47 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 7 -a 10 -c 2 -m 0.1) -r 7 " 48 | + "-d (moa.streams.generators.imbalanced.HyperplaneGenerator -i 8 -a 10 -c 2 -m 0.2) -r 8 " 49 | + "-p 125000 -w 1) " 50 | + "-p 125000 -w 1) " 51 | + "-p 125000 -w 1) " 52 | + "-p 125000 -w 1) " 53 | + "-p 125000 -w 1) " 54 | + "-p 125000 -w 1) " 55 | + "-p 125000 -w 1", 56 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 1 -r 1 -a 10 -c 2 -m 0.2) -r 1 " 57 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 2 -r 2 -a 10 -c 2 -m 0.1) -r 2 " 58 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 3 -r 3 -a 10 -c 2 -m 0.05) -r 3 " 59 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 4 -r 4 -a 10 -c 2 -m 0.01) -r 4 " 60 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 5 -r 5 -a 10 -c 2 -m 0.01) -r 5 " 61 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 6 -r 6 -a 10 -c 2 -m 0.05) -r 6 " 62 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 7 -r 7 -a 10 -c 2 -m 0.1) -r 7 " 63 | + "-d (moa.streams.generators.imbalanced.RandomRBFGenerator -i 8 -r 8 -a 10 -c 2 -m 0.2) -r 8 " 64 | + "-p 125000 -w 1) " 65 | + "-p 125000 -w 1) " 66 | + "-p 125000 -w 1) " 67 | + "-p 125000 -w 1) " 68 | + "-p 125000 -w 1) " 69 | + "-p 125000 -w 1) " 70 | + "-p 125000 -w 1", 71 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 1 -f 1 -m 0.2) -r 1 " 72 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 2 -f 1 -m 0.1) -r 2 " 73 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 3 -f 1 -m 0.05) -r 3 " 74 | + "-d 
(ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 4 -f 1 -m 0.01) -r 4 " 75 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 5 -f 1 -m 0.01) -r 5 " 76 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 6 -f 1 -m 0.05) -r 6 " 77 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 7 -f 1 -m 0.1) -r 7 " 78 | + "-d (moa.streams.generators.imbalanced.SEAGenerator -i 8 -f 1 -m 0.2) -r 8 " 79 | + "-p 125000 -w 1) " 80 | + "-p 125000 -w 1) " 81 | + "-p 125000 -w 1) " 82 | + "-p 125000 -w 1) " 83 | + "-p 125000 -w 1) " 84 | + "-p 125000 -w 1) " 85 | + "-p 125000 -w 1", 86 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 1 -f 1 -m 0.2) -r 1 " 87 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 2 -f 1 -m 0.1) -r 2 " 88 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 3 -f 1 -m 0.05) -r 3 " 89 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 4 -f 1 -m 0.01) -r 4 " 90 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 5 -f 1 -m 0.01) -r 5 " 91 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 6 -f 1 -m 0.05) -r 6 " 92 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 7 -f 1 -m 0.1) -r 7 " 93 | + "-d (moa.streams.generators.imbalanced.SineGenerator -i 8 -f 1 -m 0.2) -r 8 " 94 | + "-p 125000 -w 1) " 95 | + "-p 125000 -w 1) " 96 | + "-p 125000 -w 1) " 97 | + "-p 125000 -w 1) " 98 | + "-p 125000 -w 1) " 99 | + "-p 125000 -w 1) " 100 | + "-p 125000 -w 1", 101 | // Gradual drift 102 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.2) -r 1 " 103 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 2 -f 2 -m 0.1) -r 2 " 104 | + "-d (ConceptDriftStream -s 
(moa.streams.generators.imbalanced.AgrawalGenerator -i 3 -f 2 -m 0.05) -r 3 " 105 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 4 -f 2 -m 0.01) -r 4 " 106 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 5 -f 2 -m 0.01) -r 5 " 107 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 6 -f 2 -m 0.05) -r 6 " 108 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 7 -f 2 -m 0.1) -r 7 " 109 | + "-d (moa.streams.generators.imbalanced.AgrawalGenerator -i 8 -f 2 -m 0.2) -r 8 " 110 | + "-p 125000 -w 50000) " 111 | + "-p 125000 -w 50000) " 112 | + "-p 125000 -w 50000) " 113 | + "-p 125000 -w 50000) " 114 | + "-p 125000 -w 50000) " 115 | + "-p 125000 -w 50000) " 116 | + "-p 125000 -w 50000", 117 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 1 -f 1 -m 0.2) -r 1 " 118 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 2 -f 1 -m 0.1) -r 2 " 119 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 3 -f 1 -m 0.05) -r 3 " 120 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 4 -f 1 -m 0.01) -r 4 " 121 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 5 -f 1 -m 0.01) -r 5 " 122 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 6 -f 1 -m 0.05) -r 6 " 123 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 7 -f 1 -m 0.1) -r 7 " 124 | + "-d (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 8 -f 1 -m 0.2) -r 8 " 125 | + "-p 125000 -w 50000) " 126 | + "-p 125000 -w 50000) " 127 | + "-p 125000 -w 50000) " 128 | + "-p 125000 -w 50000) " 129 | + "-p 125000 -w 50000) " 130 | + "-p 125000 -w 50000) " 131 | + "-p 125000 -w 50000", 132 | 
"ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 1 -a 10 -c 2 -m 0.2) -r 1 " 133 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 2 -a 10 -c 2 -m 0.1) -r 2 " 134 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 3 -a 10 -c 2 -m 0.05) -r 3 " 135 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 4 -a 10 -c 2 -m 0.01) -r 4 " 136 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 5 -a 10 -c 2 -m 0.01) -r 5 " 137 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 6 -a 10 -c 2 -m 0.05) -r 6 " 138 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 7 -a 10 -c 2 -m 0.1) -r 7 " 139 | + "-d (moa.streams.generators.imbalanced.HyperplaneGenerator -i 8 -a 10 -c 2 -m 0.2) -r 8 " 140 | + "-p 125000 -w 50000) " 141 | + "-p 125000 -w 50000) " 142 | + "-p 125000 -w 50000) " 143 | + "-p 125000 -w 50000) " 144 | + "-p 125000 -w 50000) " 145 | + "-p 125000 -w 50000) " 146 | + "-p 125000 -w 50000", 147 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 1 -r 1 -a 10 -c 2 -m 0.2) -r 1 " 148 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 2 -r 2 -a 10 -c 2 -m 0.1) -r 2 " 149 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 3 -r 3 -a 10 -c 2 -m 0.05) -r 3 " 150 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 4 -r 4 -a 10 -c 2 -m 0.01) -r 4 " 151 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 5 -r 5 -a 10 -c 2 -m 0.01) -r 5 " 152 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 6 -r 6 -a 10 -c 2 -m 0.05) -r 6 " 153 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 7 -r 7 
-a 10 -c 2 -m 0.1) -r 7 " 154 | + "-d (moa.streams.generators.imbalanced.RandomRBFGenerator -i 8 -r 8 -a 10 -c 2 -m 0.2) -r 8 " 155 | + "-p 125000 -w 50000) " 156 | + "-p 125000 -w 50000) " 157 | + "-p 125000 -w 50000) " 158 | + "-p 125000 -w 50000) " 159 | + "-p 125000 -w 50000) " 160 | + "-p 125000 -w 50000) " 161 | + "-p 125000 -w 50000", 162 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 1 -f 1 -m 0.2) -r 1 " 163 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 2 -f 1 -m 0.1) -r 2 " 164 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 3 -f 1 -m 0.05) -r 3 " 165 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 4 -f 1 -m 0.01) -r 4 " 166 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 5 -f 1 -m 0.01) -r 5 " 167 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 6 -f 1 -m 0.05) -r 6 " 168 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 7 -f 1 -m 0.1) -r 7 " 169 | + "-d (moa.streams.generators.imbalanced.SEAGenerator -i 8 -f 1 -m 0.2) -r 8 " 170 | + "-p 125000 -w 50000) " 171 | + "-p 125000 -w 50000) " 172 | + "-p 125000 -w 50000) " 173 | + "-p 125000 -w 50000) " 174 | + "-p 125000 -w 50000) " 175 | + "-p 125000 -w 50000) " 176 | + "-p 125000 -w 50000", 177 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 1 -f 1 -m 0.2) -r 1 " 178 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 2 -f 1 -m 0.1) -r 2 " 179 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 3 -f 1 -m 0.05) -r 3 " 180 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 4 -f 1 -m 0.01) -r 4 " 181 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 5 -f 1 -m 0.01) -r 5 " 182 | + "-d (ConceptDriftStream -s 
(moa.streams.generators.imbalanced.SineGenerator -i 6 -f 1 -m 0.05) -r 6 " 183 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 7 -f 1 -m 0.1) -r 7 " 184 | + "-d (moa.streams.generators.imbalanced.SineGenerator -i 8 -f 1 -m 0.2) -r 8 " 185 | + "-p 125000 -w 50000) " 186 | + "-p 125000 -w 50000) " 187 | + "-p 125000 -w 50000) " 188 | + "-p 125000 -w 50000) " 189 | + "-p 125000 -w 50000) " 190 | + "-p 125000 -w 50000) " 191 | + "-p 125000 -w 50000", 192 | }; 193 | 194 | String[] generatorsFilename = new String[] { 195 | "Sudden-AgrawalGenerator-IncreasingDecreasingIR", 196 | "Sudden-AssetNegotiationGenerator-IncreasingDecreasingIR", 197 | "Sudden-HyperplaneGenerator-IncreasingDecreasingIR", 198 | "Sudden-RandomRBFGenerator-IncreasingDecreasingIR", 199 | "Sudden-SEAGenerator-IncreasingDecreasingIR", 200 | "Sudden-SineGenerator-IncreasingDecreasingIR", 201 | "Gradual-AgrawalGenerator-IncreasingDecreasingIR", 202 | "Gradual-AssetNegotiationGenerator-IncreasingDecreasingIR", 203 | "Gradual-HyperplaneGenerator-IncreasingDecreasingIR", 204 | "Gradual-RandomRBFGenerator-IncreasingDecreasingIR", 205 | "Gradual-SEAGenerator-IncreasingDecreasingIR", 206 | "Gradual-SineGenerator-IncreasingDecreasingIR", 207 | }; 208 | 209 | String[] algorithms = new String[] { 210 | "moa.classifiers.meta.imbalanced.ROSE", 211 | "moa.classifiers.meta.KUE", 212 | "moa.classifiers.meta.AccuracyWeightedEnsemble", 213 | "moa.classifiers.meta.AccuracyUpdatedEnsemble1", 214 | "moa.classifiers.meta.AccuracyUpdatedEnsemble2", 215 | "moa.classifiers.meta.DynamicWeightedMajority", 216 | "moa.classifiers.meta.SAE2", 217 | "moa.classifiers.meta.DACC", 218 | "moa.classifiers.meta.ADACC", 219 | "moa.classifiers.meta.AdaptiveRandomForest", 220 | "moa.classifiers.meta.ADOB", 221 | "moa.classifiers.meta.BOLE", 222 | "moa.classifiers.meta.GOOWE", 223 | "moa.classifiers.meta.HeterogeneousEnsembleBlast", 224 | "moa.classifiers.meta.LeveragingBag", 225 | 
"moa.classifiers.meta.OCBoost", 226 | "moa.classifiers.meta.OzaBag", 227 | "moa.classifiers.meta.OzaBagAdwin", 228 | "moa.classifiers.meta.OzaBagASHT", 229 | "moa.classifiers.meta.OzaBoost", 230 | "moa.classifiers.meta.OzaBoostAdwin", 231 | "moa.classifiers.meta.StreamingRandomPatches", 232 | "moa.classifiers.meta.UOB", 233 | "moa.classifiers.meta.OOB", 234 | "moa.classifiers.meta.imbalanced.OnlineSMOTEBagging", 235 | "moa.classifiers.meta.imbalanced.OnlineUnderOverBagging", 236 | "moa.classifiers.meta.imbalanced.CSMOTE", 237 | "moa.classifiers.meta.imbalanced.OnlineAdaBoost", 238 | "moa.classifiers.meta.imbalanced.OnlineAdaC2", 239 | "moa.classifiers.meta.imbalanced.OnlineRUSBoost", 240 | "moa.classifiers.meta.imbalanced.RebalanceStream", 241 | }; 242 | 243 | String[] algorithmsFilename = new String[algorithms.length]; 244 | 245 | for(int alg = 0; alg < algorithms.length; alg++) { 246 | algorithmsFilename[alg] = algorithms[alg].replaceAll("moa.classifiers.meta.", "").replaceAll("imbalanced.", ""); 247 | } 248 | 249 | String classpathSeparator = SystemUtils.IS_OS_UNIX ? 
/**
 * Prints, one per line, the command lines that evaluate every algorithm on
 * every instance-level-difficulty dataset with MOA's
 * {@code EvaluateInterleavedTestThenTrain} task.
 *
 * <p>Nothing is executed here: the program only writes the commands to
 * standard output so they can be redirected into a script or a job scheduler.
 */
public class Instance_Level_Difficulties {

    /**
     * Writes one evaluation command per (dataset, algorithm) pair to standard output.
     *
     * @param args ignored
     * @throws Exception kept for interface compatibility; the body itself throws nothing checked
     */
    public static void main(String[] args) throws Exception {

        // Download datasets from https://people.vcu.edu/~acano/ROSE/datasets-instance-level.zip

        // Splitting into 5 clusters and evaluating the impact of percentage of borderline and rare instances

        String[] datasets = new String[] {
            // IR 1
            "Split5",
            "Split5+Rare20",
            "Split5+Rare40",
            "Split5+Rare60",
            "Split5+Rare80",
            "Split5+Rare100",
            "Split5+Borderline20",
            "Split5+Borderline40",
            "Split5+Borderline60",
            "Split5+Borderline80",
            "Split5+Borderline100",
            "Split5+Borderline20+Rare20",
            "Split5+Borderline40+Rare40",

            // IR 10
            "Split5+Im10",
            "Split5+Im10+Rare20",
            "Split5+Im10+Rare40",
            "Split5+Im10+Rare60",
            "Split5+Im10+Rare80",
            "Split5+Im10+Rare100",
            "Split5+Im10+Borderline20",
            "Split5+Im10+Borderline40",
            "Split5+Im10+Borderline60",
            "Split5+Im10+Borderline80",
            "Split5+Im10+Borderline100",
            "Split5+Im10+Borderline20+Rare20",
            "Split5+Im10+Borderline40+Rare40",

            // IR 100
            "Split5+Im1",
            "Split5+Im1+Rare20",
            "Split5+Im1+Rare40",
            "Split5+Im1+Rare60",
            "Split5+Im1+Rare80",
            "Split5+Im1+Rare100",
            "Split5+Im1+Borderline20",
            "Split5+Im1+Borderline40",
            "Split5+Im1+Borderline60",
            "Split5+Im1+Borderline80",
            "Split5+Im1+Borderline100",
            "Split5+Im1+Borderline20+Rare20",
            "Split5+Im1+Borderline40+Rare40",
        };

        String[] algorithms = new String[] {
            "moa.classifiers.meta.imbalanced.ROSE",
            "moa.classifiers.meta.KUE",
            "moa.classifiers.meta.AccuracyWeightedEnsemble",
            "moa.classifiers.meta.AccuracyUpdatedEnsemble1",
            "moa.classifiers.meta.AccuracyUpdatedEnsemble2",
            "moa.classifiers.meta.DynamicWeightedMajority",
            "moa.classifiers.meta.SAE2",
            "moa.classifiers.meta.DACC",
            "moa.classifiers.meta.ADACC",
            "moa.classifiers.meta.AdaptiveRandomForest",
            "moa.classifiers.meta.ADOB",
            "moa.classifiers.meta.BOLE",
            "moa.classifiers.meta.GOOWE",
            "moa.classifiers.meta.HeterogeneousEnsembleBlast",
            "moa.classifiers.meta.LeveragingBag",
            "moa.classifiers.meta.OCBoost",
            "moa.classifiers.meta.OzaBag",
            "moa.classifiers.meta.OzaBagAdwin",
            "moa.classifiers.meta.OzaBagASHT",
            "moa.classifiers.meta.OzaBoost",
            "moa.classifiers.meta.OzaBoostAdwin",
            "moa.classifiers.meta.StreamingRandomPatches",
            "moa.classifiers.meta.UOB",
            "moa.classifiers.meta.OOB",
            "moa.classifiers.meta.imbalanced.OnlineSMOTEBagging",
            "moa.classifiers.meta.imbalanced.OnlineUnderOverBagging",
            "moa.classifiers.meta.imbalanced.CSMOTE",
            "moa.classifiers.meta.imbalanced.OnlineAdaBoost",
            "moa.classifiers.meta.imbalanced.OnlineAdaC2",
            "moa.classifiers.meta.imbalanced.OnlineRUSBoost",
            "moa.classifiers.meta.imbalanced.RebalanceStream",
        };

        // Short algorithm names used in the result file names. String.replace is
        // used (not replaceAll) because the prefixes are literal text: as regular
        // expressions their dots would match any character.
        String[] algorithmsFilename = new String[algorithms.length];

        for (int alg = 0; alg < algorithms.length; alg++) {
            algorithmsFilename[alg] = algorithms[alg]
                    .replace("moa.classifiers.meta.", "")
                    .replace("imbalanced.", "");
        }

        // The JDK's own classpath separator for the current platform (":" or ";").
        String classpathSeparator = java.io.File.pathSeparator;

        for (String dataset : datasets) {
            for (int alg = 0; alg < algorithms.length; alg++) {
                System.out.println("java -Xms16g -Xmx1024g -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar" + classpathSeparator + "MOA-dependencies.jar "
                        + "moa.DoTask EvaluateInterleavedTestThenTrain"
                        + " -e \"(WindowAUCImbalancedPerformanceEvaluator -w 100)\""
                        + " -s \"(ArffFileStream -f datasets-instance-level-difficulties/" + dataset + ".arff)\""
                        + " -l \"(" + algorithms[alg] + ")\""
                        + " -f 500"
                        + " -d results_instance_level_difficulties/" + algorithmsFilename[alg] + "-" + dataset + ".csv");
            }
        }
    }
}
/**
 * Prints, one per line, the command lines that evaluate every algorithm on
 * every synthetic generator at a fixed (static) imbalance ratio with MOA's
 * {@code EvaluateInterleavedTestThenTrain} task.
 *
 * <p>Nothing is executed here: the program only writes the commands to
 * standard output so they can be redirected into a script or a job scheduler.
 */
public class Static_Imbalance_Ratio {

    /**
     * Writes one evaluation command per (generator, algorithm) pair to standard output.
     *
     * @param args ignored
     * @throws Exception kept for interface compatibility; the body itself throws nothing checked
     */
    public static void main(String[] args) throws Exception {

        // Value passed as -m to each generator, kept as text so it is printed verbatim.
        // In order these correspond to IR 100, 50, 20, 10 and 5.
        String[] minorityRatios = new String[] { "0.01", "0.02", "0.05", "0.1", "0.2" };

        // Generator command lines without the trailing "-m <ratio>" option.
        String[] generatorTemplates = new String[] {
            "moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2",
            "moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 1 -f 3",
            "moa.streams.generators.imbalanced.RandomRBFGenerator -i 1 -r 1 -a 50 -c 2",
            "moa.streams.generators.imbalanced.SEAGenerator -i 1 -f 1",
            "moa.streams.generators.imbalanced.SineGenerator -i 1 -f 1",
            "moa.streams.generators.imbalanced.HyperplaneGenerator -i 1 -a 10 -c 2 -k 2 -t 0.05",
        };

        // Ratio-major order: all generators at IR 100 first, then IR 50, and so on.
        String[] generators = new String[minorityRatios.length * generatorTemplates.length];
        int next = 0;
        for (String ratio : minorityRatios) {
            for (String template : generatorTemplates) {
                generators[next++] = template + " -m " + ratio;
            }
        }

        String[] algorithms = new String[] {
            "moa.classifiers.meta.imbalanced.ROSE",
            "moa.classifiers.meta.KUE",
            "moa.classifiers.meta.AccuracyWeightedEnsemble",
            "moa.classifiers.meta.AccuracyUpdatedEnsemble1",
            "moa.classifiers.meta.AccuracyUpdatedEnsemble2",
            "moa.classifiers.meta.DynamicWeightedMajority",
            "moa.classifiers.meta.SAE2",
            "moa.classifiers.meta.DACC",
            "moa.classifiers.meta.ADACC",
            "moa.classifiers.meta.AdaptiveRandomForest",
            "moa.classifiers.meta.ADOB",
            "moa.classifiers.meta.BOLE",
            "moa.classifiers.meta.GOOWE",
            "moa.classifiers.meta.HeterogeneousEnsembleBlast",
            "moa.classifiers.meta.LeveragingBag",
            "moa.classifiers.meta.OCBoost",
            "moa.classifiers.meta.OzaBag",
            "moa.classifiers.meta.OzaBagAdwin",
            "moa.classifiers.meta.OzaBagASHT",
            "moa.classifiers.meta.OzaBoost",
            "moa.classifiers.meta.OzaBoostAdwin",
            "moa.classifiers.meta.StreamingRandomPatches",
            "moa.classifiers.meta.UOB",
            "moa.classifiers.meta.OOB",
            "moa.classifiers.meta.imbalanced.OnlineSMOTEBagging",
            "moa.classifiers.meta.imbalanced.OnlineUnderOverBagging",
            "moa.classifiers.meta.imbalanced.CSMOTE",
            "moa.classifiers.meta.imbalanced.OnlineAdaBoost",
            "moa.classifiers.meta.imbalanced.OnlineAdaC2",
            "moa.classifiers.meta.imbalanced.OnlineRUSBoost",
            "moa.classifiers.meta.imbalanced.RebalanceStream",
        };

        String[] algorithmsFilename = new String[algorithms.length];
        String[] generatorsFilename = new String[generators.length];

        // String.replace is used (not replaceAll) because the prefixes are literal
        // text: as regular expressions their dots would match any character.
        for (int alg = 0; alg < algorithms.length; alg++) {
            algorithmsFilename[alg] = algorithms[alg]
                    .replace("moa.classifiers.meta.", "")
                    .replace("imbalanced.", "");
        }

        // File-name form of a generator: package prefix dropped and blanks removed,
        // e.g. "AgrawalGenerator-i1-f2-m0.01".
        for (int gen = 0; gen < generators.length; gen++) {
            generatorsFilename[gen] = generators[gen]
                    .replace("moa.streams.generators.imbalanced.", "")
                    .replace(" ", "");
        }

        // The JDK's own classpath separator for the current platform (":" or ";").
        String classpathSeparator = java.io.File.pathSeparator;

        for (int gen = 0; gen < generators.length; gen++) {
            for (int alg = 0; alg < algorithms.length; alg++) {
                System.out.println("java -Xms16g -Xmx1024g -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar" + classpathSeparator + "MOA-dependencies.jar "
                        + "moa.DoTask EvaluateInterleavedTestThenTrain"
                        + " -e \"(WindowAUCImbalancedPerformanceEvaluator)\""
                        + " -s \"(" + generators[gen] + ")\""
                        + " -l \"(" + algorithms[alg] + ")\""
                        + " -f 500"
                        + " -d results_static_IR/" + algorithmsFilename[alg] + "-" + generatorsFilename[gen] + ".csv");
            }
        }
    }
}
moa.evaluation.WindowImbalancedClassificationPerformanceEvaluator; 25 | import moa.options.ClassOption; 26 | 27 | /** 28 | * ROSE: Robust Online Self-Adjusting Ensemble for Continual Learning on Imbalanced Drifting Data Streams 29 | * 30 | * @author Alberto Cano 31 | */ 32 | public class ROSE extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler { 33 | 34 | private static final long serialVersionUID = 1L; 35 | 36 | public ClassOption treeLearnerOption = new ClassOption("treeLearner", 'l', "RandomSubspaceHT", RandomSubspaceHT.class, "RandomSubspaceHT"); 37 | 38 | public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of trees.", 10, 1, Integer.MAX_VALUE); 39 | 40 | public FloatOption lambdaOption = new FloatOption("lambda", 'a', "The lambda parameter for bagging.", 6.0, 1.0, Float.MAX_VALUE); 41 | 42 | public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'x', "Change detector for drifts and its parameters", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-5"); 43 | 44 | public ClassOption warningDetectionMethodOption = new ClassOption("warningDetectionMethod", 'p', "Change detector for warnings (start training bkg learner)", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-4"); 45 | 46 | public IntOption featureSpaceOption = new IntOption("feature", 'f', "The feature space to employ {1: uniform distribution with at last 50% of features; 2: normal distribution based on percentageFeaturesMean (70% default)}.", 1, 1, 2); 47 | 48 | public FloatOption percentageFeaturesMean = new FloatOption("percentageFeaturesMean", 'm', "Mean for percentage of featues selected", 0.7, 0, 1); 49 | 50 | public FloatOption theta = new FloatOption("theta", 't', "The time decay factor for class size.", 0.99, 0, 1); 51 | 52 | public IntOption windowSizeOption = new IntOption("window", 'w', "The number of instances in the sliding window.", 500, 1, Integer.MAX_VALUE); 53 | 54 | protected 
ROSEBaseLearner[] ensemble; 55 | protected ROSEBaseLearner[] ensembleBackground; 56 | protected InstancesTimestap[] instancesClass; 57 | 58 | protected double classSize[]; 59 | protected long instancesSeen; 60 | protected long firstWarningOn; 61 | protected boolean warningDetected; 62 | protected WindowImbalancedClassificationPerformanceEvaluator evaluator; 63 | 64 | @Override 65 | public void resetLearningImpl() { 66 | this.warningDetected = true; 67 | this.firstWarningOn = 0; 68 | this.instancesClass = null; 69 | this.ensemble = null; 70 | this.ensembleBackground = null; 71 | this.classSize = null; 72 | this.instancesSeen = 0; 73 | this.evaluator = new WindowImbalancedClassificationPerformanceEvaluator(); 74 | } 75 | 76 | @Override 77 | public void trainOnInstanceImpl(Instance instance) { 78 | this.instancesSeen++; 79 | 80 | this.instancesClass[(int) instance.classValue()].add(instance, instancesSeen); 81 | 82 | if(this.ensemble == null) { 83 | initEnsemble(instance); 84 | } 85 | 86 | for (int i=0; i < classSize.length; i++) { 87 | classSize[i] = theta.getValue() * classSize[i] + (1d - theta.getValue()) * ((int) instance.classValue() == i ? 
1d:0d); 88 | } 89 | 90 | double lambda = lambdaOption.getValue() + lambdaOption.getValue() * Math.log(classSize[Utils.maxIndex(classSize)] / classSize[(int) instance.classValue()]); 91 | 92 | for (int i = 0; i < this.ensemble.length; i++) { 93 | DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(instance)); 94 | InstanceExample example = new InstanceExample(instance); 95 | this.ensemble[i].evaluator.addResult(example, vote.getArrayRef()); 96 | 97 | int k = MiscUtils.poisson(lambda, this.classifierRandom); 98 | if (k > 0) { 99 | this.ensemble[i].trainOnInstance(instance, k, this.instancesSeen); 100 | } 101 | } 102 | 103 | if(!this.warningDetected) { 104 | for (int i = 0; i < this.ensemble.length; i++) { 105 | if(this.ensemble[i].warningDetected) { 106 | this.warningDetected = true; 107 | this.firstWarningOn = instancesSeen; 108 | break; 109 | } 110 | } 111 | 112 | if(this.warningDetected) { 113 | int numberAttributes = instance.numAttributes() - 1; 114 | 115 | RandomSubspaceHT treeLearner = (RandomSubspaceHT) getPreparedClassOption(this.treeLearnerOption); 116 | treeLearner.resetLearning(); 117 | 118 | WindowImbalancedClassificationPerformanceEvaluator classificationEvaluator = new WindowImbalancedClassificationPerformanceEvaluator(); 119 | 120 | for(int i = 0; i < this.ensembleSizeOption.getValue(); i++) { 121 | 122 | int subspaceSize = -1; 123 | 124 | if(featureSpaceOption.getValue() == 1) { 125 | subspaceSize = 1 + (int) Math.floor(numberAttributes/2) + this.classifierRandom.nextInt((int)Math.ceil(numberAttributes/2)); 126 | } else if(featureSpaceOption.getValue() == 2) { 127 | subspaceSize = (int) Math.round(this.percentageFeaturesMean.getValue() * numberAttributes + ((1.0 - this.percentageFeaturesMean.getValue()) * numberAttributes) * this.classifierRandom.nextGaussian() * 0.5); 128 | } 129 | 130 | if (subspaceSize > numberAttributes) { 131 | subspaceSize = numberAttributes; 132 | } else if (subspaceSize <= 0) { 133 | subspaceSize = 1; 
134 | } 135 | 136 | treeLearner.subspaceSizeOption.setValue(subspaceSize); 137 | treeLearner.setRandomSeed(this.classifierRandom.nextInt(Integer.MAX_VALUE)); 138 | 139 | this.ensembleBackground[i] = new ROSEBaseLearner( 140 | (RandomSubspaceHT) treeLearner.copy(), 141 | (WindowImbalancedClassificationPerformanceEvaluator) classificationEvaluator.copy(), 142 | this.instancesSeen, 143 | driftDetectionMethodOption, 144 | warningDetectionMethodOption, 145 | false); 146 | } 147 | 148 | int[] indexClass = new int[instance.numClasses()]; 149 | Long[] oldestTimestamps = new Long[instance.numClasses()]; 150 | 151 | do { 152 | Long oldestTimestamp = Long.MAX_VALUE; 153 | int nextClass = -1; 154 | 155 | for(int c = 0; c < instance.numClasses(); c++) { 156 | oldestTimestamps[c] = this.instancesClass[c].getTimestamp(indexClass[c]); 157 | 158 | if(oldestTimestamps[c] != null && oldestTimestamps[c] < oldestTimestamp) { 159 | oldestTimestamp = oldestTimestamps[c]; 160 | nextClass = c; 161 | } 162 | } 163 | 164 | if(nextClass == -1) { 165 | break; 166 | } 167 | 168 | Instance windowInstance = this.instancesClass[nextClass].getInstance(indexClass[nextClass]); 169 | 170 | for (int i = 0; i < this.ensembleBackground.length; i++) { 171 | int k = MiscUtils.poisson(lambdaOption.getValue(), this.classifierRandom); 172 | if (k > 0) { 173 | this.ensembleBackground[i].trainOnInstance(windowInstance, k, this.instancesSeen); 174 | } 175 | } 176 | 177 | indexClass[nextClass]++; 178 | 179 | } while(true); 180 | } 181 | } 182 | 183 | if(this.warningDetected) { 184 | for (int i = 0; i < this.ensembleBackground.length; i++) { 185 | DoubleVector vote = new DoubleVector(this.ensembleBackground[i].getVotesForInstance(instance)); 186 | InstanceExample example = new InstanceExample(instance); 187 | this.ensembleBackground[i].evaluator.addResult(example, vote.getArrayRef()); 188 | 189 | int k = MiscUtils.poisson(lambda, this.classifierRandom); 190 | if (k > 0) { 191 | 
this.ensembleBackground[i].trainOnInstance(instance, k, this.instancesSeen); 192 | } 193 | } 194 | 195 | if(this.instancesSeen - this.firstWarningOn == this.evaluator.widthOption.getValue()) { 196 | // Compare the ensemble and the background ensemble. Select the best components 197 | 198 | ArrayList classifiers = new ArrayList(); 199 | ArrayList selection = new ArrayList(); 200 | ArrayList kappas = new ArrayList(); 201 | ArrayList accuracies = new ArrayList(); 202 | 203 | for (int i = 0; i < this.ensemble.length; i++) { 204 | classifiers.add(this.ensemble[i]); 205 | kappas.add(this.ensemble[i].evaluator.getKappa()); 206 | accuracies.add(this.ensemble[i].evaluator.getAccuracy()); 207 | } 208 | 209 | for (int i = 0; i < this.ensembleBackground.length; i++) { 210 | classifiers.add(this.ensembleBackground[i]); 211 | kappas.add(this.ensembleBackground[i].evaluator.getKappa()); 212 | accuracies.add(this.ensembleBackground[i].evaluator.getAccuracy()); 213 | } 214 | 215 | for (int i = 0; i < this.ensemble.length; i++) { 216 | 217 | double maxKappaAccuracy = -1; 218 | int maxKappaAccuracyClassifier = -1; 219 | 220 | for (int j = 0; j < this.ensemble.length + this.ensembleBackground.length - i; j++) { 221 | if(kappas.get(j) * accuracies.get(j) >= maxKappaAccuracy) { 222 | maxKappaAccuracy = kappas.get(j) * accuracies.get(j); 223 | maxKappaAccuracyClassifier = j; 224 | } 225 | } 226 | 227 | selection.add(classifiers.get(maxKappaAccuracyClassifier)); 228 | 229 | classifiers.remove(maxKappaAccuracyClassifier); 230 | kappas.remove(maxKappaAccuracyClassifier); 231 | accuracies.remove(maxKappaAccuracyClassifier); 232 | } 233 | 234 | for (int i = 0; i < this.ensemble.length; i++) { 235 | this.ensemble[i] = selection.get(i); 236 | } 237 | 238 | for (int i = 0; i < this.ensembleBackground.length; i++) { 239 | this.ensembleBackground[i] = null; 240 | } 241 | 242 | this.warningDetected = false; 243 | } 244 | } 245 | } 246 | 247 | @Override 248 | public double[] 
getVotesForInstance(Instance instance) { 249 | if(this.ensemble == null) { 250 | initEnsemble(instance); 251 | } 252 | 253 | DoubleVector combinedVote = new DoubleVector(); 254 | DoubleVector combinedVoteUnweighted = new DoubleVector(); 255 | 256 | for(int i = 0; i < this.ensemble.length; i++) { 257 | DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(instance)); 258 | if (vote.sumOfValues() > 0.0) { 259 | vote.normalize(); 260 | 261 | combinedVoteUnweighted.addValues(vote); 262 | 263 | double kappa = this.ensemble[i].evaluator.getKappa(); 264 | double accuracy = this.ensemble[i].evaluator.getAccuracy(); 265 | 266 | if(kappa > 0.0) { 267 | vote.scaleValues(kappa * accuracy); 268 | combinedVote.addValues(vote); 269 | } 270 | } 271 | } 272 | 273 | if(combinedVote.sumOfValues() == 0) { 274 | return combinedVoteUnweighted.getArrayRef(); 275 | } else { 276 | return combinedVote.getArrayRef(); 277 | } 278 | } 279 | 280 | @Override 281 | public boolean isRandomizable() { 282 | return true; 283 | } 284 | 285 | @Override 286 | public void getModelDescription(StringBuilder arg0, int arg1) { 287 | } 288 | 289 | @Override 290 | protected Measurement[] getModelMeasurementsImpl() { 291 | return null; 292 | } 293 | 294 | protected void initEnsemble(Instance instance) { 295 | 296 | this.instancesClass = new InstancesTimestap[instance.numClasses()]; 297 | 298 | for(int i = 0; i < instance.numClasses(); i++) { 299 | this.instancesClass[i] = new InstancesTimestap(this.getModelContext()); 300 | } 301 | 302 | classSize = new double[instance.numClasses()]; 303 | 304 | for (int i=0; i < classSize.length; i++) { 305 | classSize[i] = 1d / classSize.length; 306 | } 307 | 308 | int ensembleSize = this.ensembleSizeOption.getValue(); 309 | int numberAttributes = instance.numAttributes() - 1; 310 | 311 | this.ensemble = new ROSEBaseLearner[ensembleSize]; 312 | this.ensembleBackground = new ROSEBaseLearner[ensembleSize]; 313 | 314 | 
WindowImbalancedClassificationPerformanceEvaluator classificationEvaluator = new WindowImbalancedClassificationPerformanceEvaluator(); 315 | RandomSubspaceHT treeLearner = (RandomSubspaceHT) getPreparedClassOption(this.treeLearnerOption); 316 | treeLearner.resetLearning(); 317 | 318 | // Primary ensemble 319 | for(int i = 0; i < ensembleSize; i++) { 320 | int subspaceSize = -1; 321 | 322 | if(featureSpaceOption.getValue() == 1) { 323 | subspaceSize = 1 + (int) Math.floor(numberAttributes/2) + this.classifierRandom.nextInt((int)Math.ceil(numberAttributes/2)); 324 | } else if(featureSpaceOption.getValue() == 2) { 325 | subspaceSize = (int) Math.round(this.percentageFeaturesMean.getValue() * numberAttributes + ((1.0 - this.percentageFeaturesMean.getValue()) * numberAttributes) * this.classifierRandom.nextGaussian() * 0.5); 326 | } 327 | 328 | if (subspaceSize > numberAttributes) { 329 | subspaceSize = numberAttributes; 330 | } else if (subspaceSize <= 0) { 331 | subspaceSize = 1; 332 | } 333 | 334 | treeLearner.subspaceSizeOption.setValue(subspaceSize); 335 | treeLearner.setRandomSeed(this.classifierRandom.nextInt(Integer.MAX_VALUE)); 336 | 337 | this.ensemble[i] = new ROSEBaseLearner( 338 | (RandomSubspaceHT) treeLearner.copy(), 339 | (WindowImbalancedClassificationPerformanceEvaluator) classificationEvaluator.copy(), 340 | this.instancesSeen, 341 | driftDetectionMethodOption, 342 | warningDetectionMethodOption, 343 | false); 344 | } 345 | 346 | // Background ensemble 347 | for(int i = 0; i < this.ensembleSizeOption.getValue(); i++) { 348 | int subspaceSize = -1; 349 | 350 | if(featureSpaceOption.getValue() == 1) { 351 | subspaceSize = 1 + (int) Math.floor(numberAttributes/2) + this.classifierRandom.nextInt((int)Math.ceil(numberAttributes/2)); 352 | } else if(featureSpaceOption.getValue() == 2) { 353 | subspaceSize = (int) Math.round(this.percentageFeaturesMean.getValue() * numberAttributes + ((1.0 - this.percentageFeaturesMean.getValue()) * numberAttributes) * 
this.classifierRandom.nextGaussian() * 0.5); 354 | } 355 | 356 | if (subspaceSize > numberAttributes) { 357 | subspaceSize = numberAttributes; 358 | } else if (subspaceSize <= 0) { 359 | subspaceSize = 1; 360 | } 361 | 362 | treeLearner.subspaceSizeOption.setValue(subspaceSize); 363 | treeLearner.setRandomSeed(this.classifierRandom.nextInt(Integer.MAX_VALUE)); 364 | 365 | this.ensembleBackground[i] = new ROSEBaseLearner( (RandomSubspaceHT) treeLearner.copy(), 366 | (WindowImbalancedClassificationPerformanceEvaluator) classificationEvaluator.copy(), 367 | this.instancesSeen, 368 | driftDetectionMethodOption, 369 | warningDetectionMethodOption, 370 | false); 371 | } 372 | } 373 | 374 | @Override 375 | public ImmutableCapabilities defineImmutableCapabilities() { 376 | if (this.getClass() == ROSE.class) 377 | return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE); 378 | else 379 | return new ImmutableCapabilities(Capability.VIEW_STANDARD); 380 | } 381 | 382 | @Override 383 | public String getPurposeString() { 384 | return "ROSE: Robust Online Self-Adjusting Ensemble for Continual Learning on Imbalanced Drifting Data Streams"; 385 | } 386 | 387 | /** 388 | * Inner class that represents a single tree member of the forest. 389 | * It contains some analysis information, such as the numberOfDriftsDetected, 390 | */ 391 | protected final class ROSEBaseLearner extends AbstractMOAObject { 392 | private static final long serialVersionUID = -3758478930527651262L; 393 | public long createdOn; 394 | public boolean warningDetected; 395 | public long lastDriftOn; 396 | public long lastWarningOn; 397 | public RandomSubspaceHT classifier; 398 | public boolean isBackgroundLearner; 399 | 400 | // The drift and warning object parameters. 
401 | protected ClassOption driftOption; 402 | protected ClassOption warningOption; 403 | 404 | // Drift and warning detection 405 | protected ChangeDetector driftDetectionMethod; 406 | protected ChangeDetector warningDetectionMethod; 407 | 408 | // Bkg learner 409 | protected ROSEBaseLearner bkgLearner; 410 | // Statistics 411 | public WindowImbalancedClassificationPerformanceEvaluator evaluator; 412 | protected int numberOfDriftsDetected; 413 | protected int numberOfWarningsDetected; 414 | 415 | public ROSEBaseLearner (RandomSubspaceHT instantiatedClassifier, WindowImbalancedClassificationPerformanceEvaluator evaluatorInstantiated, long instancesSeen, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) { 416 | this.createdOn = instancesSeen; 417 | this.lastDriftOn = 0; 418 | this.lastWarningOn = 0; 419 | this.warningDetected = false; 420 | 421 | this.classifier = instantiatedClassifier; 422 | this.evaluator = evaluatorInstantiated; 423 | 424 | this.numberOfDriftsDetected = 0; 425 | this.numberOfWarningsDetected = 0; 426 | this.isBackgroundLearner = isBackgroundLearner; 427 | 428 | this.driftOption = driftOption; 429 | this.driftDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.driftOption)).copy(); 430 | 431 | this.warningOption = warningOption; 432 | this.warningDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.warningOption)).copy(); 433 | } 434 | 435 | public void reset() { 436 | if(this.bkgLearner != null) { 437 | this.classifier = this.bkgLearner.classifier; 438 | 439 | this.driftDetectionMethod = this.bkgLearner.driftDetectionMethod; 440 | this.warningDetectionMethod = this.bkgLearner.warningDetectionMethod; 441 | 442 | this.evaluator = this.bkgLearner.evaluator; 443 | this.createdOn = this.bkgLearner.createdOn; 444 | this.bkgLearner = null; 445 | } 446 | else { 447 | this.classifier.resetLearning(); 448 | this.createdOn = instancesSeen; 449 | this.driftDetectionMethod = ((ChangeDetector) 
getPreparedClassOption(this.driftOption)).copy(); 450 | } 451 | 452 | this.lastWarningOn = 0; 453 | this.lastDriftOn = 0; 454 | this.warningDetected = false; 455 | this.evaluator.reset(); 456 | } 457 | 458 | public void trainOnInstance(Instance instance, double weight, long instancesSeen) { 459 | Instance weightedInstance = (Instance) instance.copy(); 460 | weightedInstance.setWeight(instance.weight() * weight); 461 | this.classifier.trainOnInstance(weightedInstance); 462 | 463 | if(this.bkgLearner != null) 464 | this.bkgLearner.classifier.trainOnInstance(weightedInstance); 465 | 466 | // Should it use a drift detector? Also, is it a backgroundLearner? If so, then do not "incept" another one. 467 | if(!this.isBackgroundLearner) { 468 | boolean correctlyClassifies = this.classifier.correctlyClassifies(instance); 469 | // Check for warning only if useBkgLearner is active 470 | // Update the warning detection method 471 | this.warningDetectionMethod.input(correctlyClassifies ? 0 : 1); 472 | // Check if there was a change 473 | if(this.warningDetectionMethod.getChange()) { 474 | this.warningDetected = true; 475 | this.lastWarningOn = instancesSeen; 476 | this.numberOfWarningsDetected++; 477 | // Create a new bkgTree classifier 478 | RandomSubspaceHT bkgClassifier = (RandomSubspaceHT) this.classifier.copy(); 479 | bkgClassifier.resetLearning(); 480 | bkgClassifier.setup(this.classifier.listAttributes, this.classifier.instanceHeader); 481 | 482 | // Resets the evaluator 483 | WindowImbalancedClassificationPerformanceEvaluator bkgEvaluator = (WindowImbalancedClassificationPerformanceEvaluator) this.evaluator.copy(); 484 | bkgEvaluator.reset(); 485 | 486 | // Create a new bkgLearner object 487 | this.bkgLearner = new ROSEBaseLearner(bkgClassifier, bkgEvaluator, instancesSeen, this.driftOption, this.warningOption, true); 488 | 489 | // Update the warning detection object for the current object 490 | // (this effectively resets changes made to the object while it was still a 
bkg learner). 491 | this.warningDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.warningOption)).copy(); 492 | } else { 493 | this.warningDetected = false; 494 | } 495 | 496 | /*********** drift detection ***********/ 497 | 498 | // Update the DRIFT detection method 499 | this.driftDetectionMethod.input(correctlyClassifies ? 0 : 1); 500 | // Check if there was a change 501 | if(this.driftDetectionMethod.getChange()) { 502 | this.lastDriftOn = instancesSeen; 503 | this.numberOfDriftsDetected++; 504 | this.reset(); 505 | } 506 | } 507 | } 508 | 509 | public double[] getVotesForInstance(Instance instance) { 510 | DoubleVector vote = new DoubleVector(this.classifier.getVotesForInstance(instance)); 511 | return vote.getArrayRef(); 512 | } 513 | 514 | @Override 515 | public void getDescription(StringBuilder sb, int indent) { 516 | } 517 | } 518 | 519 | private class InstancesTimestap { 520 | private Instances instances; 521 | private ArrayList timestamps; 522 | 523 | public InstancesTimestap(InstancesHeader header) { 524 | instances = new Instances(header); 525 | timestamps = new ArrayList(); 526 | } 527 | 528 | public void add(Instance instance, long timestamp) { 529 | instances.add(instance); 530 | timestamps.add(timestamp); 531 | 532 | if (instances.size() > windowSizeOption.getValue() / instance.numClasses()) { 533 | instances.delete(0); 534 | timestamps.remove(0); 535 | } 536 | } 537 | 538 | public Instance getInstance(int index) { 539 | if(index >= instances.size()) 540 | return null; 541 | else 542 | return instances.get(index); 543 | } 544 | 545 | public Long getTimestamp(int index) { 546 | if(index >= timestamps.size()) 547 | return null; 548 | else 549 | return timestamps.get(index); 550 | } 551 | } 552 | } -------------------------------------------------------------------------------- /src/main/java/moa/classifiers/trees/RandomSubspaceHT.java: -------------------------------------------------------------------------------- 1 | package 
moa.classifiers.trees; 2 | 3 | import java.util.ArrayList; 4 | 5 | import com.github.javacliparser.IntOption; 6 | import com.yahoo.labs.samoa.instances.Instance; 7 | import com.yahoo.labs.samoa.instances.Instances; 8 | 9 | /** 10 | * Hoeffding Tree on a fixed random subspace of features 11 | * 12 | * @author Alberto Cano 13 | */ 14 | public class RandomSubspaceHT extends HoeffdingTree { 15 | 16 | private static final long serialVersionUID = 1L; 17 | 18 | public IntOption subspaceSizeOption = new IntOption("subspaceFeaturesSize", 'k', "Number of features", 1, 1, Integer.MAX_VALUE); 19 | 20 | public boolean[] listAttributes; 21 | 22 | public Instances instanceHeader; 23 | 24 | public void setup(boolean[] listAttibutes, Instances instanceHeader) { 25 | this.listAttributes = listAttibutes; 26 | this.instanceHeader = instanceHeader; 27 | } 28 | 29 | @Override 30 | public String getPurposeString() { 31 | return "ROSE Hoeffding Tree for data streams"; 32 | } 33 | 34 | @Override 35 | public void resetLearningImpl() { 36 | super.resetLearningImpl(); 37 | } 38 | 39 | @Override 40 | public void trainOnInstanceImpl(Instance instance) { 41 | if(this.listAttributes == null) { 42 | setupListAttributes(instance); 43 | } 44 | 45 | Instance instanceProjected = instance.copy(); 46 | 47 | for(int att = instanceProjected.numAttributes()-2; att >= 0; att--) { 48 | if(this.listAttributes[att] == false) { 49 | instanceProjected.deleteAttributeAt(att); 50 | } 51 | } 52 | 53 | instanceProjected.setDataset(instanceHeader); 54 | 55 | super.trainOnInstanceImpl(instanceProjected); 56 | } 57 | 58 | @Override 59 | public double[] getVotesForInstance(Instance instance) { 60 | if(this.listAttributes == null) { 61 | setupListAttributes(instance); 62 | } 63 | 64 | Instance instanceProjected = instance.copy(); 65 | 66 | for(int att = instanceProjected.numAttributes()-2; att >= 0; att--) { 67 | if(this.listAttributes[att] == false) { 68 | instanceProjected.deleteAttributeAt(att); 69 | } 70 | } 71 | 72 
| instanceProjected.setDataset(instanceHeader); 73 | 74 | return super.getVotesForInstance(instanceProjected); 75 | } 76 | 77 | @Override 78 | public boolean isRandomizable() { 79 | return true; 80 | } 81 | 82 | protected void setupListAttributes(Instance instance) { 83 | int numberAttributes = instance.numAttributes() - 1; 84 | int subspaceSize = subspaceSizeOption.getValue(); 85 | 86 | this.listAttributes = new boolean[numberAttributes]; 87 | this.instanceHeader = new Instances(instance.dataset()); 88 | 89 | ArrayList attributesPool = new ArrayList(); 90 | 91 | for(int i = 0; i < numberAttributes; i++) { 92 | attributesPool.add(i); 93 | } 94 | 95 | for(int i = 0; i < subspaceSize; i++) { 96 | this.listAttributes[attributesPool.remove(this.classifierRandom.nextInt(attributesPool.size()))] = true; 97 | } 98 | 99 | for(int att = numberAttributes-1; att >= 0; att--) { 100 | if(this.listAttributes[att] == false) { 101 | this.instanceHeader.deleteAttributeAt(att); 102 | } 103 | } 104 | } 105 | } -------------------------------------------------------------------------------- /src/main/java/moa/evaluation/WindowAUCImbalancedPerformanceEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * ImbalancedPerformanceEvaluator.java 3 | * Copyright (C) 2016 Poznan University of Technology 4 | * @author Dariusz Brzezinski (dbrzezinski@cs.put.poznan.pl) 5 | * @author Tomasz Pewinski 6 | * 7 | * Licensed under the Apache License, Version 2.0 (the "License"); 8 | * you may not use this file except in compliance with the License. 9 | * You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 18 | */ 19 | package moa.evaluation; 20 | 21 | import java.util.TreeSet; 22 | 23 | import moa.core.Example; 24 | import moa.core.Measurement; 25 | import moa.core.ObjectRepository; 26 | import moa.core.Utils; 27 | import moa.options.AbstractOptionHandler; 28 | 29 | import com.github.javacliparser.IntOption; 30 | import com.yahoo.labs.samoa.instances.Instance; 31 | import com.yahoo.labs.samoa.instances.InstanceImpl; 32 | import com.yahoo.labs.samoa.instances.Prediction; 33 | 34 | import moa.tasks.TaskMonitor; 35 | 36 | /** 37 | * Classification evaluator that updates evaluation results using a sliding 38 | * window. Performance measures designed for class imbalance problems. 39 | * Only to be used for binary classification problems with unweighted instances. 40 | * Class 0 - majority/negative examples, class 1 - minority, positive examples. 41 | * Prequential AUC calculation as described and analyzed in: D. Brzezinski, 42 | * J. Stefanowski, "Prequential AUC: Properties of the Area Under the ROC 43 | * Curve for Data Streams with Concept Drift", Knowledge and Information 44 | * Systems, 2017. 
45 | * 46 | * @author Dariusz Brzezinski (dbrzezinski at cs.put.poznan.pl) 47 | * @author Tomasz Pewinski 48 | */ 49 | public class WindowAUCImbalancedPerformanceEvaluator extends 50 | AbstractOptionHandler implements ClassificationPerformanceEvaluator { 51 | 52 | private static final long serialVersionUID = 1L; 53 | 54 | public IntOption widthOption = new IntOption("width", 'w', 55 | "Size of Window", 1000); 56 | 57 | protected double totalObservedInstances = 0; 58 | private Estimator aucEstimator; 59 | private SimpleEstimator weightMajorityClassifier; 60 | protected int numClasses; 61 | 62 | public class SimpleEstimator { 63 | protected double len; 64 | 65 | protected double sum; 66 | 67 | public void add(double value) { 68 | sum += value; 69 | len++; 70 | } 71 | 72 | public double estimation(){ 73 | return sum/len; 74 | } 75 | } 76 | 77 | public class Estimator { 78 | 79 | public class Score implements Comparable { 80 | /** 81 | * Predicted score of the example 82 | */ 83 | protected double value; 84 | 85 | /** 86 | * Age of example - position in the window where the example was 87 | * added 88 | */ 89 | protected int posWindow; 90 | 91 | /** 92 | * True if example's true label is positive 93 | */ 94 | protected boolean isPositive; 95 | 96 | /** 97 | * Constructor. 98 | * 99 | * @param value 100 | * score value 101 | * @param position 102 | * score position in window (defines its age) 103 | * @param isPositive 104 | * true if the example's true label is positive 105 | */ 106 | public Score(double value, int position, boolean isPositive) { 107 | this.value = value; 108 | this.posWindow = position; 109 | this.isPositive = isPositive; 110 | } 111 | 112 | /** 113 | * Sort descending based on score value. 
114 | */ 115 | @Override 116 | public int compareTo(Score o) { 117 | if (o.value < this.value) { 118 | return -1; 119 | } else if (o.value > this.value){ 120 | return 1; 121 | } else { 122 | if (!o.isPositive && this.isPositive) { 123 | return -1; 124 | } else if (o.isPositive && !this.isPositive){ 125 | return 1; 126 | } else { 127 | if (o.posWindow > this.posWindow) { 128 | return -1; 129 | } else if (o.posWindow < this.posWindow){ 130 | return 1; 131 | } else { 132 | return 0; 133 | } 134 | } 135 | } 136 | } 137 | 138 | @Override 139 | public boolean equals(Object o) { 140 | return (o instanceof Score) && ((Score)o).posWindow == this.posWindow; 141 | } 142 | } 143 | 144 | protected TreeSet sortedScores; 145 | 146 | protected TreeSet holdoutSortedScores; 147 | 148 | protected Score[] window; 149 | 150 | protected double[] predictions; 151 | 152 | protected int posWindow; 153 | 154 | protected int size; 155 | 156 | protected double numPos; 157 | 158 | protected double numNeg; 159 | 160 | protected double holdoutNumPos; 161 | 162 | protected double holdoutNumNeg; 163 | 164 | protected double correctPredictions; 165 | 166 | protected double correctPositivePredictions; 167 | 168 | protected double[] columnKappa; 169 | 170 | protected double[] rowKappa; 171 | 172 | protected int confusionMatrix[][]; 173 | 174 | public Estimator(int sizeWindow) { 175 | this.sortedScores = new TreeSet(); 176 | this.holdoutSortedScores = new TreeSet(); 177 | this.size = sizeWindow; 178 | this.window = new Score[sizeWindow]; 179 | this.predictions = new double[sizeWindow]; 180 | 181 | this.rowKappa = new double[numClasses]; 182 | this.columnKappa = new double[numClasses]; 183 | this.confusionMatrix = new int[numClasses][numClasses]; 184 | 185 | for (int i = 0; i < numClasses; i++) { 186 | this.rowKappa[i] = 0.0; 187 | this.columnKappa[i] = 0.0; 188 | } 189 | 190 | this.posWindow = 0; 191 | this.numPos = 0; 192 | this.numNeg = 0; 193 | this.holdoutNumPos = 0; 194 | this.holdoutNumNeg = 0; 
195 | this.correctPredictions = 0; 196 | this.correctPositivePredictions = 0; 197 | } 198 | 199 | public void add(double score, boolean isPositive, boolean correctPrediction) { 200 | // // periodically update holdout evaluation 201 | if (size > 0 && posWindow % this.size == 0) { 202 | this.holdoutSortedScores = new TreeSet(); 203 | 204 | for (Score s : this.sortedScores) { 205 | this.holdoutSortedScores.add(s); 206 | } 207 | 208 | this.holdoutNumPos = this.numPos; 209 | this.holdoutNumNeg = this.numNeg; 210 | } 211 | 212 | // // if the window is used and it's full 213 | if (size > 0 && posWindow >= this.size) { 214 | // // remove the oldest example 215 | sortedScores.remove(window[posWindow % size]); 216 | correctPredictions -= predictions[posWindow % size]; 217 | correctPositivePredictions -= window[posWindow % size].isPositive ? predictions[posWindow % size] : 0; 218 | 219 | if (window[posWindow % size].isPositive) { 220 | numPos--; 221 | } else { 222 | numNeg--; 223 | } 224 | 225 | int oldestExampleTrueClass = window[posWindow % size].isPositive ? 1 : 0; 226 | int oldestExamplePredictedClass = predictions[posWindow % size] == 1.0 ? oldestExampleTrueClass : Math.abs(oldestExampleTrueClass - 1); 227 | 228 | this.rowKappa[oldestExamplePredictedClass] -= 1; 229 | this.columnKappa[oldestExampleTrueClass] -= 1; 230 | this.confusionMatrix[oldestExamplePredictedClass][oldestExampleTrueClass]--; 231 | } 232 | 233 | // // add new example 234 | Score newScore = new Score(score, posWindow, isPositive); 235 | sortedScores.add(newScore); 236 | correctPredictions += correctPrediction ? 1 : 0; 237 | correctPositivePredictions += correctPrediction && isPositive ? 1 : 0; 238 | 239 | int trueClass = isPositive ? 1 : 0; 240 | int predictedClass = correctPrediction ? 
trueClass : Math.abs(trueClass - 1); 241 | this.rowKappa[predictedClass] += 1; 242 | this.columnKappa[trueClass] += 1; 243 | this.confusionMatrix[predictedClass][trueClass]++; 244 | 245 | 246 | if (newScore.isPositive) { 247 | numPos++; 248 | } else { 249 | numNeg++; 250 | } 251 | 252 | if (size > 0) { 253 | window[posWindow % size] = newScore; 254 | predictions[posWindow % size] = correctPrediction ? 1 : 0; 255 | } 256 | 257 | //// posWindow needs to be always incremented to differentiate between examples in the red-black tree 258 | posWindow++; 259 | } 260 | 261 | public double getAUC() { 262 | double AUC = 0; 263 | double c = 0; 264 | double prevc = 0; 265 | double lastPosScore = Double.MAX_VALUE; 266 | 267 | if (numPos == 0 || numNeg == 0) { 268 | return 1; 269 | } 270 | 271 | for (Score s : sortedScores){ 272 | if(s.isPositive) { 273 | if (s.value != lastPosScore) { 274 | prevc = c; 275 | lastPosScore = s.value; 276 | } 277 | 278 | c += 1; 279 | } else { 280 | if (s.value == lastPosScore) { 281 | // tie 282 | AUC += ((double)(c + prevc))/2.0; 283 | } else { 284 | AUC += c; 285 | } 286 | } 287 | } 288 | 289 | return AUC / (numPos * numNeg); 290 | } 291 | 292 | public double getHoldoutAUC() { 293 | double AUC = 0; 294 | double c = 0; 295 | double prevc = 0; 296 | double lastPosScore = Double.MAX_VALUE; 297 | 298 | if (holdoutSortedScores.isEmpty()) { 299 | return 0; 300 | } 301 | 302 | if (holdoutNumPos == 0 || holdoutNumNeg == 0) { 303 | return 1; 304 | } 305 | 306 | for (Score s : holdoutSortedScores){ 307 | if(s.isPositive) { 308 | if (s.value != lastPosScore) { 309 | prevc = c; 310 | lastPosScore = s.value; 311 | } 312 | 313 | c += 1; 314 | } else { 315 | if (s.value == lastPosScore) { 316 | // tie 317 | AUC += ((double)(c + prevc))/2.0; 318 | } else { 319 | AUC += c; 320 | } 321 | } 322 | } 323 | 324 | return AUC / (holdoutNumPos * holdoutNumNeg); 325 | } 326 | 327 | public double getScoredAUC() { 328 | double AOC = 0; 329 | double AUC = 0; 330 | double r = 
0; 331 | double prevr = 0; 332 | double c = 0; 333 | double prevc = 0; 334 | double R_plus, R_minus; 335 | double lastPosScore = Double.MAX_VALUE; 336 | double lastNegScore = Double.MAX_VALUE; 337 | 338 | if (numPos == 0 || numNeg == 0) { 339 | return 1; 340 | } 341 | 342 | for (Score s : sortedScores){ 343 | if(s.isPositive) { 344 | if (s.value != lastPosScore) { 345 | prevc = c; 346 | lastPosScore = s.value; 347 | } 348 | 349 | c += s.value; 350 | 351 | if (s.value == lastNegScore) { 352 | // tie 353 | AOC += ((double)(r + prevr))/2.0; 354 | } else { 355 | AOC += r; 356 | } 357 | } else { 358 | if (s.value != lastNegScore) { 359 | prevr = r; 360 | lastNegScore = s.value; 361 | } 362 | 363 | r += s.value; 364 | 365 | if (s.value == lastPosScore) { 366 | // tie 367 | AUC += ((double)(c + prevc))/2.0; 368 | } else { 369 | AUC += c; 370 | } 371 | } 372 | } 373 | 374 | R_minus = (numPos*r - AOC)/(numPos * numNeg); 375 | R_plus = (AUC)/(numPos * numNeg); 376 | return R_plus - R_minus; 377 | } 378 | 379 | public double getRatio() { 380 | if(numNeg == 0) { 381 | return Double.MAX_VALUE; 382 | } else { 383 | return numPos/numNeg; 384 | } 385 | } 386 | 387 | public double getAccuracy() { 388 | if (size > 0) { 389 | return totalObservedInstances > 0.0 ? correctPredictions / Math.min(size, totalObservedInstances) : 0.0; 390 | } else { 391 | return totalObservedInstances > 0.0 ? 
correctPredictions / totalObservedInstances : 0.0; 392 | } 393 | } 394 | 395 | public double getKappa() { 396 | double p0 = getAccuracy(); 397 | double pc = 0.0; 398 | 399 | if (size > 0) { 400 | for (int i = 0; i < numClasses; i++) { 401 | pc += (this.rowKappa[i]/Math.min(size, totalObservedInstances)) * (this.columnKappa[i]/Math.min(size, totalObservedInstances)); 402 | } 403 | } else { 404 | for (int i = 0; i < numClasses; i++) { 405 | pc += (this.rowKappa[i]/totalObservedInstances) * (this.columnKappa[i]/totalObservedInstances); 406 | } 407 | } 408 | return (p0 - pc) / (1.0 - pc); 409 | } 410 | 411 | private double getKappaM() { 412 | double p0 = getAccuracy(); 413 | double pc = weightMajorityClassifier.estimation(); 414 | 415 | return (p0 - pc) / (1.0 - pc); 416 | } 417 | 418 | public double getGMean() { 419 | double positiveAccuracy = correctPositivePredictions / numPos; 420 | double negativeAccuracy = (correctPredictions - correctPositivePredictions) / numNeg; 421 | return Math.sqrt(positiveAccuracy * negativeAccuracy); 422 | } 423 | 424 | public double getRecall() { 425 | return correctPositivePredictions / numPos; 426 | } 427 | } 428 | 429 | @Override 430 | public void reset() { 431 | reset(this.numClasses); 432 | } 433 | 434 | public void reset(int numClasses) { 435 | if (numClasses != 2) { 436 | throw new RuntimeException( 437 | "Too many classes (" 438 | + numClasses 439 | + "). 
AUC evaluation can be performed only for two-class problems!"); 440 | } 441 | 442 | this.numClasses = numClasses; 443 | 444 | this.aucEstimator = new Estimator(this.widthOption.getValue()); 445 | this.weightMajorityClassifier = new SimpleEstimator(); 446 | this.totalObservedInstances = 0; 447 | } 448 | 449 | @Override 450 | public void addResult(Example exampleInstance, double[] classVotes) { 451 | InstanceImpl inst = (InstanceImpl)exampleInstance.getData(); 452 | double weight = inst.weight(); 453 | 454 | if (inst.classIsMissing() == false){ 455 | int trueClass = (int) inst.classValue(); 456 | 457 | if (weight > 0.0) { 458 | // // initialize evaluator 459 | if (totalObservedInstances == 0) { 460 | reset(inst.dataset().numClasses()); 461 | } 462 | this.totalObservedInstances += 1; 463 | 464 | //// if classVotes has length == 1, then the negative (0) class got all the votes 465 | Double normalizedVote = 0.0; 466 | 467 | //// normalize and add score 468 | if(classVotes.length == 2) { 469 | normalizedVote = classVotes[1]/(classVotes[0] + classVotes[1]); 470 | } 471 | 472 | if(normalizedVote.isNaN()){ 473 | normalizedVote = 0.0; 474 | } 475 | 476 | this.aucEstimator.add(normalizedVote, trueClass == 1, Utils.maxIndex(classVotes) == trueClass); 477 | this.weightMajorityClassifier.add((this.aucEstimator.getRatio() <= 1 ? 0 : 1) == trueClass ? 
weight: 0); 478 | } 479 | } 480 | } 481 | 482 | protected double AUC(int[][] confusionMatrix, int positiveClass, int negativeClass) 483 | { 484 | int tp = confusionMatrix[positiveClass][positiveClass]; 485 | int fp = confusionMatrix[positiveClass][negativeClass]; 486 | int tn = confusionMatrix[negativeClass][negativeClass]; 487 | int fn = confusionMatrix[negativeClass][positiveClass]; 488 | 489 | double tpRate = 1.0, fpRate = 0.0; 490 | 491 | if(tp + fn != 0) 492 | tpRate = tp / (double) (tp + fn); 493 | 494 | if(fp + tn != 0) 495 | fpRate = fp / (double) (fp + tn); 496 | 497 | double auc = (1.0 + tpRate - fpRate) / 2.0; 498 | 499 | return auc; 500 | } 501 | 502 | @Override 503 | public Measurement[] getPerformanceMeasurements() { 504 | 505 | Measurement[] measurement = new Measurement[11 + numClasses*numClasses]; 506 | 507 | measurement[0] = new Measurement("classified instances", this.totalObservedInstances); 508 | measurement[1] = new Measurement("AUC", AUC(this.aucEstimator.confusionMatrix, 1, 0)); 509 | measurement[1] = new Measurement("pAUC", this.aucEstimator.getAUC()); 510 | measurement[2] = new Measurement("sAUC", this.aucEstimator.getScoredAUC()); 511 | measurement[3] = new Measurement("Accuracy", this.aucEstimator.getAccuracy()); 512 | measurement[4] = new Measurement("Kappa", this.aucEstimator.getKappa()); 513 | measurement[5] = new Measurement("Periodical holdout AUC", this.aucEstimator.getHoldoutAUC()); 514 | measurement[6] = new Measurement("Pos/Neg ratio", this.aucEstimator.getRatio()); 515 | measurement[7] = new Measurement("G-Mean", this.aucEstimator.getGMean()); 516 | measurement[8] = new Measurement("Recall", this.aucEstimator.getRecall()); 517 | measurement[9] = new Measurement("KappaM", this.aucEstimator.getKappaM()); 518 | measurement[10] = new Measurement("Classes", numClasses); 519 | 520 | for(int i = 0; i < numClasses; i++) { 521 | for(int j = 0; j < numClasses; j++) { 522 | measurement[11 + i*numClasses + j] = new 
Measurement("CM["+i+"]["+j+"]", this.aucEstimator.confusionMatrix[i][j]); 523 | } 524 | } 525 | 526 | return measurement; 527 | } 528 | 529 | @Override 530 | public void getDescription(StringBuilder sb, int indent) { 531 | Measurement.getMeasurementsDescription(getPerformanceMeasurements(), 532 | sb, indent); 533 | } 534 | 535 | @Override 536 | public void prepareForUseImpl(TaskMonitor monitor, 537 | ObjectRepository repository) { 538 | } 539 | 540 | public Estimator getAucEstimator() { 541 | return aucEstimator; 542 | } 543 | 544 | @Override 545 | public void addResult(Example arg0, Prediction arg1) { 546 | throw new RuntimeException("Designed for scoring classifiers"); 547 | } 548 | } 549 | -------------------------------------------------------------------------------- /src/main/java/moa/evaluation/WindowAUCMultiClassImbalancedPerformanceEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * WindowClassificationPerformanceEvaluator.java 3 | * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand 4 | * @author Albert Bifet (abifet@cs.waikato.ac.nz) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | * 19 | */ 20 | package moa.evaluation; 21 | 22 | import java.io.Serializable; 23 | import java.util.ArrayList; 24 | import java.util.Iterator; 25 | import java.util.List; 26 | import java.util.TreeSet; 27 | 28 | import moa.core.Example; 29 | import moa.core.Measurement; 30 | import moa.core.ObjectRepository; 31 | import moa.options.AbstractOptionHandler; 32 | import moa.tasks.TaskMonitor; 33 | 34 | import com.github.javacliparser.IntOption; 35 | import com.yahoo.labs.samoa.instances.Instance; 36 | import com.yahoo.labs.samoa.instances.Prediction; 37 | 38 | import weka.core.Utils; 39 | 40 | 41 | /** 42 | * Classification evaluator that updates evaluation results using a sliding 43 | * window. Only to be used for binary classification problems with unweighted instances. 44 | * 45 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 46 | * @version $Revision: 7 $ 47 | */ 48 | public class WindowAUCMultiClassImbalancedPerformanceEvaluator extends AbstractOptionHandler implements ClassificationPerformanceEvaluator, Serializable { 49 | 50 | private static final long serialVersionUID = 1L; 51 | 52 | public IntOption widthOption = new IntOption("width", 'w', "Size of Window", 500); 53 | 54 | protected double totalObservedInstances = 0; 55 | 56 | protected Estimator aucEstimator; 57 | 58 | protected int numClasses; 59 | 60 | public class Estimator implements Serializable { 61 | 62 | public class Score implements Comparable, Serializable { 63 | /** 64 | * Predicted score of the example 65 | */ 66 | protected double[] value; 67 | 68 | /** 69 | * true class label index of the example 70 | */ 71 | protected int realClass; 72 | 73 | /** 74 | * Age of example - position in the window where the example was 75 | * added 76 | */ 77 | protected int posWindow; 78 | 79 | //positive class index for this tree 80 | protected int pos_class; 81 | /** 82 | * Constructor. 
83 | * 84 | * @param value 85 | * score value 86 | * @param position 87 | * score position in window (defines its age) 88 | * @param isPositive 89 | * true if the example's true label is positive 90 | */ 91 | public Score(int trueClass, double[] value, int position, int pos_class) { 92 | this.realClass = trueClass; 93 | this.value = value; 94 | this.posWindow = position; 95 | this.pos_class = pos_class; 96 | } 97 | 98 | /** 99 | * Sort descending based on score value of the positive class of the current tree. 100 | */ 101 | @Override 102 | public int compareTo(Score o) { 103 | if (o.value[pos_class] < this.value[pos_class]) { 104 | return -1; 105 | } else if (o.value[pos_class] > this.value[pos_class]){ 106 | return 1; 107 | } else { 108 | if (o.posWindow > this.posWindow) { 109 | return -1; 110 | } else if (o.posWindow < this.posWindow){ 111 | return 1; 112 | } else { 113 | return 0; 114 | } 115 | } 116 | } 117 | 118 | @Override 119 | public boolean equals(Object o) { 120 | return (o instanceof Score) && ((Score)o).posWindow == this.posWindow; 121 | } 122 | 123 | } 124 | 125 | public class RB_tree{ 126 | protected int pos_class; 127 | protected int neg_class; 128 | protected double numPos; 129 | protected double numNeg; 130 | protected TreeSet sortedScores; 131 | 132 | public RB_tree(int pos_class_idx, int neg_class_idx) { 133 | pos_class = pos_class_idx; 134 | neg_class = neg_class_idx; 135 | numPos = 0; 136 | numNeg = 0; 137 | sortedScores = new TreeSet(); 138 | } 139 | } 140 | 141 | protected List rb_trees;//the set of red-black trees for any 2 classes (numClasses*(numClasses-1)) 142 | 143 | protected Score[] window;//store current window of examples 144 | 145 | protected double[] predictions;//store correct/incorrect prediction results for all examples in the window: if the example is correctly classified or not: 1-correct, 0-incorrect. 
146 | 147 | protected int[] predictionsClass; 148 | 149 | protected int posWindow;//the position where the oldest example that needs to be removed (if window is full) and where the new example that needs to be added. 150 | 151 | protected int size;//window size 152 | 153 | protected double correctPredictions;//the number of correctly classified examples in the current window 154 | 155 | protected double[] correctPrediction_perclass; //the number of correctly classified examples in the current window, for each class (for calculating recall and g-mean) 156 | 157 | protected double[] totalObservedInstances_perclass_window;//the number of examples in the current window, for each class (for calculating recall and g-mean) 158 | 159 | protected double[] inst_classvotes;//the normalized votes for the current data sample 160 | 161 | protected int inst_trueClass;//the true class of the current data sample 162 | 163 | protected WindowAUCImbalancedPerformanceEvaluator[] auc_2c;//store 2-class version of AUC for each class being the positive class (for calculating Provost's weighted AUC) 164 | 165 | protected WindowAUCImbalancedPerformanceEvaluator[] ewauc_2c;//store 2-class version of AUC for each class being the positive class (for calculating Provost's weighted AUC but with equal weight) 166 | 167 | protected double[] columnKappa; 168 | protected double[] rowKappa; 169 | protected int confusionMatrix[][]; 170 | 171 | public Estimator(int sizeWindow) { 172 | 173 | this.rb_trees = new ArrayList(); 174 | for(int i = 0; i < numClasses-1; i++) { 175 | for(int j = i+1; j < numClasses; j++) { 176 | RB_tree tree_i = new RB_tree(i,j);//a redblack tree with class i as the positive class 177 | RB_tree tree_j = new RB_tree(j,i);//a redblack tree with class j as the positive class 178 | this.rb_trees.add(tree_i); 179 | this.rb_trees.add(tree_j); 180 | } 181 | } 182 | 183 | this.size = sizeWindow; 184 | this.window = new Score[sizeWindow]; 185 | this.predictions = new double[sizeWindow]; 
/**
 * Records one scored example in the estimator.
 *
 * If a window is in use (size > 0) and already holds `size` examples, the example stored in
 * slot `posWindow % size` is evicted first: its contribution is subtracted from the accuracy
 * counters, the per-class counters, the kappa row/column totals, the confusion matrix and
 * every pairwise tree returned by find_trees for its class. The new example is then added to
 * the same statistics, written into that slot and inserted into every tree returned by
 * find_trees for trueClass.
 *
 * When size is 0 nothing is written to the window arrays and nothing is ever evicted, but the
 * counters and trees are still updated.
 *
 * @param score             votes for the example (stored as-is in the Score objects)
 * @param trueClass         true class index of the example
 * @param correctPrediction whether the prediction matched the true class
 * @param predictedClass    predicted class index (first index of confusionMatrix)
 */
public void add(double[] score, int trueClass, boolean correctPrediction, int predictedClass) {

    // remember the most recent example
    this.inst_classvotes = score;
    this.inst_trueClass = trueClass;

    // indices (into rb_trees) of the trees that involve the new example's class
    int[] tree_idx_add = this.find_trees(trueClass);
    // if the window is used and it's full
    if (size > 0 && posWindow >= this.size) {
        int idx_remove = window[posWindow % size].realClass;//the class label of the example to be removed from the window
        int[] tree_idx_remove = this.find_trees(idx_remove);//find the indices of trees containing the class label of the example to be removed from the window
        // predictions[] holds 1 or 0 (see below), so this undoes the evicted example's contribution
        correctPredictions -= predictions[posWindow % size];
        correctPrediction_perclass[idx_remove] -= predictions[posWindow % size];
        totalObservedInstances_perclass_window[idx_remove] -= 1;
        for(int i = 0; i < tree_idx_remove.length; i++) {
            RB_tree tree_i = this.rb_trees.get(tree_idx_remove[i]);
            // remove the evicted example from every tree that involves its class;
            // find_node locates it by matching window slot (posWindow % size)
            Score node = this.find_node(tree_i, posWindow);
            tree_i.sortedScores.remove(node);

            // keep the tree's positive/negative counts in sync with its contents
            if (window[posWindow % size].realClass == tree_i.pos_class) {
                tree_i.numPos--;
            } else {
                tree_i.numNeg--;
            }
        }

        // undo the evicted example in the kappa totals and confusion matrix ([predicted][true])
        this.rowKappa[predictionsClass[posWindow % size]] -= 1;
        this.columnKappa[idx_remove] -= 1;
        this.confusionMatrix[predictionsClass[posWindow % size]][idx_remove]--;
    }

    // add new example
    Score newScore = new Score(trueClass, score, posWindow, -1);//new score for updating windows
    correctPredictions += correctPrediction ? 1 : 0;
    correctPrediction_perclass[trueClass] += correctPrediction ? 1 : 0;
    totalObservedInstances_perclass_window[trueClass] += 1;
    this.rowKappa[predictedClass] += 1;
    this.columnKappa[trueClass] += 1;
    this.confusionMatrix[predictedClass][trueClass]++;

    // overwrite the slot that was just evicted (or fill the next free slot while warming up)
    if (size > 0) {
        window[posWindow % size] = newScore;
        predictions[posWindow % size] = correctPrediction ? 1 : 0;
        predictionsClass[posWindow % size] = predictedClass;
    }
    for(int i = 0; i < tree_idx_add.length; i++) {
        RB_tree tree_i = this.rb_trees.get(tree_idx_add[i]);
        // each tree gets its own Score tagged with that tree's positive class
        Score newScore_tree = new Score(trueClass, score, posWindow, tree_i.pos_class);//new score for updating trees
        tree_i.sortedScores.add(newScore_tree);

        if (trueClass == tree_i.pos_class) {
            tree_i.numPos++;
        } else {
            tree_i.numNeg++;
        }
    }
    // posWindow needs to be always incremented to differentiate between examples in the red-black tree
    // NOTE(review): posWindow is an int that is never wrapped; once it overflows to a negative
    // value, posWindow % size is negative and the window indexing above would fail - confirm
    // whether streams of that length are expected.
    posWindow++;
}
the given tree 321 | public double getAUC(RB_tree current_tree) { 322 | double AUC = 0; 323 | double c = 0; 324 | if (current_tree.numPos == 0 || current_tree.numNeg == 0) { 325 | return 0; 326 | } 327 | 328 | for (Score s : current_tree.sortedScores){ 329 | if(s.realClass==current_tree.pos_class) { 330 | c += 1; 331 | } else { 332 | AUC += c; 333 | } 334 | } 335 | 336 | return AUC / (current_tree.numPos * current_tree.numNeg); 337 | } 338 | 339 | //calculate AUC based on the given positive class index and negative class index. 340 | public double getAUC(int pos_class, int neg_class) { 341 | double AUC = 0; 342 | double c = 0; 343 | int idx = this.find_onetree(pos_class, neg_class); 344 | RB_tree current_tree = this.rb_trees.get(idx); 345 | 346 | if (current_tree.numPos == 0 || current_tree.numNeg == 0) { 347 | return 0; 348 | } 349 | 350 | for (Score s : current_tree.sortedScores){ 351 | if(s.realClass==pos_class) { 352 | c += 1; 353 | } else { 354 | AUC += c; 355 | } 356 | } 357 | 358 | return AUC / (current_tree.numPos * current_tree.numNeg); 359 | } 360 | 361 | 362 | /** 363 | * Provost and Domingo's weighted AUC [2003Tree Induction for Probability-Based Ranking]: 364 | * compute the expected AUC, which is the weighted average of the AUCs obtained taking each 365 | * class as the reference class in turn (i.e., making it class 0 and all other 366 | * classes class 1). The weight of a class�s AUC is the class�s frequency in the data. 367 | * */ 368 | public double getWeightedAUC() { 369 | double wAUC = 0.0; 370 | double[] class_weights = new double[numClasses]; 371 | 372 | // get class weights 373 | for(int c = 0; c < numClasses; c++) { 374 | if(totalObservedInstances>0 && totalObservedInstances 0) { 448 | return totalObservedInstances > 0.0 ? correctPredictions / Math.min(size, totalObservedInstances) : 0.0; 449 | } else { 450 | return totalObservedInstances > 0.0 ? 
correctPredictions / totalObservedInstances : 0.0; 451 | } 452 | } 453 | 454 | public double getKappa() { 455 | double p0 = getAccuracy(); 456 | double pc = 0.0; 457 | 458 | if (size > 0) { 459 | for (int i = 0; i < numClasses; i++) { 460 | pc += (this.rowKappa[i]/Math.min(size, totalObservedInstances)) * (this.columnKappa[i]/Math.min(size, totalObservedInstances)); 461 | } 462 | } else { 463 | for (int i = 0; i < numClasses; i++) { 464 | pc += (this.rowKappa[i]/totalObservedInstances) * (this.columnKappa[i]/totalObservedInstances); 465 | } 466 | } 467 | return (p0 - pc) / (1.0 - pc); 468 | } 469 | 470 | //Recall of Class i in the current window 471 | public double getRecall(int class_idx) { 472 | return totalObservedInstances_perclass_window[class_idx] > 0.0 ? correctPrediction_perclass[class_idx] / totalObservedInstances_perclass_window[class_idx] : 0.0; 473 | } 474 | 475 | //G-mean in the current window 476 | public double getGmean() { 477 | double gmean = 1.0; 478 | for(int i = 0; i < numClasses; i++) { 479 | gmean = gmean * this.getRecall(i); 480 | } 481 | gmean = Math.pow(gmean, (double)1/numClasses); 482 | return gmean; 483 | } 484 | 485 | } 486 | 487 | 488 | public void reset(int numClasses) { 489 | this.numClasses = numClasses; 490 | 491 | this.aucEstimator = new Estimator(this.widthOption.getValue()); 492 | this.totalObservedInstances = 0; 493 | } 494 | 495 | public void addResult(Instance inst, double[] classVotes) { 496 | 497 | double weight = inst.weight(); 498 | int trueClass = (int) inst.classValue(); 499 | 500 | if (weight > 0.0) { 501 | // // initialize evaluator 502 | if (totalObservedInstances == 0) { 503 | reset(inst.dataset().numClasses()); 504 | } 505 | this.totalObservedInstances += 1; 506 | 507 | if(classVotes.length != numClasses) { 508 | classVotes = new double[this.numClasses]; 509 | for(int i = 0; i < numClasses; i++) 510 | classVotes[i] = 1.0 / (double) numClasses; 511 | } 512 | 513 | double[] normalizedVote = classVotes.clone();//get a 
deep copy of classVotes 514 | 515 | for(int i = 0; i < normalizedVote.length; i++) { 516 | if(Double.isNaN(normalizedVote[i])) 517 | normalizedVote[i] = 0.0; 518 | if(Double.isInfinite(normalizedVote[i])) 519 | normalizedVote[i] = Double.MAX_VALUE; 520 | } 521 | 522 | if(Utils.sum(normalizedVote) == 0) { normalizedVote[0] = 1; } 523 | 524 | //// normalize and add score 525 | Utils.normalize(normalizedVote); 526 | 527 | this.aucEstimator.add(normalizedVote, trueClass, Utils.maxIndex(classVotes) == trueClass, Utils.maxIndex(classVotes)); 528 | } 529 | } 530 | 531 | public Measurement[] getPerformanceMeasurements() { 532 | Measurement[] measurement = new Measurement[8 + numClasses*numClasses]; 533 | 534 | measurement[0] = new Measurement("classified instances", this.totalObservedInstances); 535 | measurement[1] = new Measurement("PMAUC", this.aucEstimator.getPMAUC()); 536 | measurement[2] = new Measurement("WMAUC", this.aucEstimator.getWeightedAUC()); 537 | measurement[3] = new Measurement("EWMAUC", this.aucEstimator.getEqualWeightedAUC()); 538 | measurement[4] = new Measurement("Accuracy", this.aucEstimator.getAccuracy()); 539 | measurement[5] = new Measurement("Kappa", this.aucEstimator.getKappa()); 540 | measurement[6] = new Measurement("G-Mean", this.aucEstimator.getGmean()); 541 | measurement[7] = new Measurement("Classes", numClasses); 542 | 543 | for(int i = 0; i < numClasses; i++) { 544 | for(int j = 0; j < numClasses; j++) { 545 | measurement[8 + i*numClasses + j] = new Measurement("CM["+i+"]["+j+"]", this.aucEstimator.confusionMatrix[i][j]); 546 | } 547 | } 548 | 549 | return measurement; 550 | } 551 | 552 | @Override 553 | public void getDescription(StringBuilder sb, int indent) { 554 | Measurement.getMeasurementsDescription(getPerformanceMeasurements(), 555 | sb, indent); 556 | } 557 | 558 | @Override 559 | public void prepareForUseImpl(TaskMonitor monitor, 560 | ObjectRepository repository) { 561 | } 562 | 563 | public Estimator getAucEstimator() { 564 | 
return aucEstimator; 565 | } 566 | 567 | 568 | @Override 569 | public void reset() { 570 | reset(this.numClasses); 571 | } 572 | 573 | 574 | @Override 575 | public void addResult(Example example, double[] classVotes) { 576 | addResult(example.getData(), classVotes); 577 | } 578 | 579 | @Override 580 | public void addResult(Example testInst, Prediction prediction) { 581 | throw new RuntimeException("Designed for scoring classifiers"); 582 | } 583 | } -------------------------------------------------------------------------------- /src/main/java/moa/evaluation/WindowImbalancedClassificationPerformanceEvaluator.java: -------------------------------------------------------------------------------- 1 | package moa.evaluation; 2 | 3 | import moa.core.Example; 4 | import moa.core.Measurement; 5 | import moa.core.ObjectRepository; 6 | import moa.core.Utils; 7 | import moa.options.AbstractOptionHandler; 8 | 9 | import com.github.javacliparser.IntOption; 10 | import com.yahoo.labs.samoa.instances.Instance; 11 | import com.yahoo.labs.samoa.instances.InstanceImpl; 12 | import com.yahoo.labs.samoa.instances.Prediction; 13 | 14 | import moa.tasks.TaskMonitor; 15 | 16 | /** 17 | * Classification evaluator that updates evaluation results using a sliding window. 18 | * Performance measures designed for class imbalance problems (binary and multiclass). 
19 | * 20 | * @author Alberto Cano 21 | */ 22 | public class WindowImbalancedClassificationPerformanceEvaluator extends AbstractOptionHandler implements ClassificationPerformanceEvaluator { 23 | 24 | private static final long serialVersionUID = 1L; 25 | 26 | public IntOption widthOption = new IntOption("width", 'w', "Size of Window", 1000); 27 | 28 | protected int confusionMatrix[][]; 29 | 30 | protected int numClasses; 31 | 32 | protected int totalObservedInstances; 33 | 34 | protected int positionWindow; 35 | 36 | protected int predictionsTrue[]; 37 | 38 | protected int predictionsPredicted[]; 39 | 40 | @Override 41 | public void reset() { 42 | reset(this.numClasses); 43 | } 44 | 45 | public void reset(int numClasses) { 46 | this.numClasses = numClasses; 47 | this.positionWindow = 0; 48 | this.totalObservedInstances = 0; 49 | this.confusionMatrix = new int[numClasses][numClasses]; 50 | this.predictionsTrue = new int[this.widthOption.getValue()]; 51 | this.predictionsPredicted = new int[this.widthOption.getValue()]; 52 | } 53 | 54 | @Override 55 | public void addResult(Example exampleInstance, double[] classVotes) { 56 | InstanceImpl inst = (InstanceImpl)exampleInstance.getData(); 57 | double weight = inst.weight(); 58 | 59 | if (inst.classIsMissing() == false){ 60 | int trueClass = (int) inst.classValue(); 61 | int predictedClass = Utils.maxIndex(classVotes); 62 | 63 | if (weight > 0.0) { 64 | // // initialize evaluator 65 | if (totalObservedInstances == 0) { 66 | reset(inst.dataset().numClasses()); 67 | } 68 | 69 | this.totalObservedInstances++; 70 | 71 | if(totalObservedInstances > this.widthOption.getValue()) 72 | { 73 | this.confusionMatrix[predictionsPredicted[positionWindow]][predictionsTrue[positionWindow]]--; 74 | } 75 | 76 | this.predictionsTrue[positionWindow] = trueClass; 77 | this.predictionsPredicted[positionWindow] = predictedClass; 78 | this.confusionMatrix[predictedClass][trueClass]++; 79 | 80 | positionWindow = (positionWindow + 1) % 
this.widthOption.getValue(); 81 | } 82 | } 83 | } 84 | 85 | protected double Kappa(int[][] confusionMatrix) 86 | { 87 | int correctedClassified = 0; 88 | int numberInstancesTotal = 0; 89 | int[] numberInstances = new int[numClasses]; 90 | int[] predictedInstances = new int[numClasses]; 91 | 92 | for(int i = 0; i < numClasses; i++) 93 | { 94 | correctedClassified += confusionMatrix[i][i]; 95 | 96 | for(int j = 0; j < numClasses; j++) 97 | { 98 | numberInstances[i] += confusionMatrix[j][i]; 99 | predictedInstances[i] += confusionMatrix[i][j]; 100 | } 101 | 102 | numberInstancesTotal += numberInstances[i]; 103 | } 104 | 105 | double mul = 0; 106 | 107 | for(int i = 0; i < numClasses; i++) 108 | mul += numberInstances[i] * predictedInstances[i]; 109 | 110 | if(numberInstancesTotal*numberInstancesTotal - mul != 0) 111 | return ((numberInstancesTotal * correctedClassified) - mul) / (double) ((numberInstancesTotal*numberInstancesTotal) - mul); 112 | else 113 | return 1.0; 114 | } 115 | 116 | protected double AUC(int[][] confusionMatrix) 117 | { 118 | if(numClasses == 2) 119 | return AUC(confusionMatrix,0,1); // Assumes 0 positive (minority), 1 negative (majority) 120 | else 121 | { 122 | /** Multi-class AUC **/ 123 | double auc = 0.0; 124 | 125 | for(int i = 0; i < numClasses; i++) 126 | for(int j = 0; j < numClasses; j++) 127 | if(i != j) 128 | auc += AUC(confusionMatrix,i,j); 129 | 130 | return auc / (double) (numClasses * (numClasses-1)); 131 | } 132 | } 133 | 134 | protected double AUC(int[][] confusionMatrix, int Class1, int Class2) 135 | { 136 | int tp = confusionMatrix[Class1][Class1]; 137 | int fp = confusionMatrix[Class1][Class2]; 138 | int tn = confusionMatrix[Class2][Class2]; 139 | int fn = confusionMatrix[Class2][Class1]; 140 | 141 | double tpRate = 1.0, fpRate = 0.0; 142 | 143 | if(tp + fn != 0) 144 | tpRate = tp / (double) (tp + fn); 145 | 146 | if(fp + tn != 0) 147 | fpRate = fp / (double) (fp + tn); 148 | 149 | double auc = (1.0 + tpRate - fpRate) / 2.0; 
150 | 151 | return auc; 152 | } 153 | 154 | protected double GMean(int[][] confusionMatrix) 155 | { 156 | double gmean = 1.0; 157 | 158 | int[] numberInstances = new int[numClasses]; 159 | 160 | for(int i = 0; i < numClasses; i++) 161 | { 162 | for(int j = 0; j < numClasses; j++) 163 | numberInstances[i] += confusionMatrix[j][i]; 164 | 165 | if(numberInstances[i] != 0) 166 | gmean *= (confusionMatrix[i][i] / (double) numberInstances[i]); 167 | } 168 | 169 | return Math.pow(gmean, 1.0 / (double) numClasses); 170 | } 171 | 172 | protected double Accuracy(int[][] confusionMatrix) 173 | { 174 | int correctedClassified = 0; 175 | int numberInstancesTotal = 0; 176 | 177 | for(int i = 0; i < numClasses; i++) 178 | { 179 | correctedClassified += confusionMatrix[i][i]; 180 | 181 | for(int j = 0; j < numClasses; j++) 182 | numberInstancesTotal += confusionMatrix[i][j]; 183 | } 184 | 185 | return correctedClassified / (double) numberInstancesTotal; 186 | } 187 | 188 | protected double AvgAccuracy(int[][] confusionMatrix) 189 | { 190 | int[] numberInstances = new int[numClasses]; 191 | 192 | for(int i = 0; i < numClasses; i++) 193 | for(int j = 0; j < numClasses; j++) 194 | numberInstances[i] += confusionMatrix[j][i]; 195 | 196 | double avgAccuracy = 0; 197 | int existingClasses = 0; 198 | 199 | for(int i = 0; i < numClasses; i++) 200 | if(numberInstances[i] != 0) 201 | { 202 | avgAccuracy += confusionMatrix[i][i] / (double) numberInstances[i]; 203 | existingClasses++; 204 | } 205 | 206 | return avgAccuracy / (double) existingClasses; 207 | } 208 | 209 | protected double Precision(int[][] confusionMatrix) 210 | { 211 | int[] numberPredictions = new int[numClasses]; 212 | 213 | for(int i = 0; i < numClasses; i++) 214 | for(int j = 0; j < numClasses; j++) 215 | numberPredictions[i] += confusionMatrix[i][j]; 216 | 217 | double precision = 0; 218 | int existingClasses = 0; 219 | 220 | for(int i = 0; i < numClasses; i++) 221 | if(numberPredictions[i] != 0) 222 | { 223 | precision 
+= confusionMatrix[i][i] / (double) numberPredictions[i]; 224 | existingClasses++; 225 | } 226 | 227 | return precision / (double) existingClasses; 228 | } 229 | 230 | protected double Recall(int[][] confusionMatrix) 231 | { 232 | int[] numberInstances = new int[numClasses]; 233 | 234 | for(int i = 0; i < numClasses; i++) 235 | for(int j = 0; j < numClasses; j++) 236 | numberInstances[i] += confusionMatrix[j][i]; 237 | 238 | double recall = 0; 239 | int existingClasses = 0; 240 | 241 | for(int i = 0; i < numClasses; i++) 242 | if(numberInstances[i] != 0) 243 | { 244 | recall += confusionMatrix[i][i] / (double) numberInstances[i]; 245 | existingClasses++; 246 | } 247 | 248 | return recall / (double) existingClasses; 249 | } 250 | 251 | protected double Ratio(int[][] confusionMatrix, int targetClass) 252 | { 253 | int numberInstances = 0; 254 | int numberInstancesTotal = 0; 255 | 256 | for(int i = 0; i < numClasses; i++) 257 | { 258 | numberInstances += confusionMatrix[i][targetClass]; 259 | 260 | for(int j = 0; j < numClasses; j++) 261 | numberInstancesTotal += confusionMatrix[i][j]; 262 | } 263 | 264 | return numberInstances / (double) numberInstancesTotal; 265 | } 266 | 267 | 268 | public double getAccuracy() { 269 | return Accuracy(this.confusionMatrix); 270 | } 271 | 272 | public double getKappa() { 273 | return Kappa(this.confusionMatrix); 274 | } 275 | 276 | @Override 277 | public Measurement[] getPerformanceMeasurements() { 278 | 279 | Measurement[] measurement = new Measurement[8 + numClasses]; 280 | 281 | measurement[0] = new Measurement("classified instances", this.totalObservedInstances); 282 | measurement[1] = new Measurement("Accuracy", Accuracy(this.confusionMatrix)); 283 | measurement[2] = new Measurement("AvgAccuracy", AvgAccuracy(this.confusionMatrix)); 284 | measurement[3] = new Measurement("AUC", AUC(this.confusionMatrix)); 285 | measurement[4] = new Measurement("Kappa", Kappa(this.confusionMatrix)); 286 | measurement[5] = new 
Measurement("G-Mean", GMean(this.confusionMatrix)); 287 | measurement[6] = new Measurement("Precision", Precision(this.confusionMatrix)); 288 | measurement[7] = new Measurement("Recall", Recall(this.confusionMatrix)); 289 | 290 | for(int i = 0; i < numClasses; i++) 291 | measurement[8 + i] = new Measurement("Ratio-Class-" + i, Ratio(this.confusionMatrix, i)); 292 | 293 | return measurement; 294 | } 295 | 296 | @Override 297 | public void getDescription(StringBuilder sb, int indent) { 298 | Measurement.getMeasurementsDescription(getPerformanceMeasurements(), sb, indent); 299 | } 300 | 301 | @Override 302 | public void prepareForUseImpl(TaskMonitor monitor, 303 | ObjectRepository repository) { 304 | } 305 | 306 | @Override 307 | public void addResult(Example arg0, Prediction arg1) { 308 | throw new RuntimeException("Designed for scoring classifiers"); 309 | } 310 | } -------------------------------------------------------------------------------- /src/main/java/moa/streams/filters/AddNoiseFilterFeatures.java: -------------------------------------------------------------------------------- 1 | /* 2 | * AddNoiseFilter.java 3 | * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand 4 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | * 19 | */ 20 | package moa.streams.filters; 21 | 22 | import java.util.ArrayList; 23 | import java.util.Random; 24 | 25 | import moa.core.AutoExpandVector; 26 | import moa.core.DoubleVector; 27 | import moa.core.GaussianEstimator; 28 | import com.yahoo.labs.samoa.instances.InstancesHeader; 29 | import com.github.javacliparser.FloatOption; 30 | import com.github.javacliparser.IntOption; 31 | import com.yahoo.labs.samoa.instances.Instance; 32 | 33 | /** 34 | * Filter for adding random noise to examples in a stream. 35 | * 36 | * @author Alberto Cano (acano@vcu.edu) 37 | * @version $Revision: 1 $ 38 | */ 39 | public class AddNoiseFilterFeatures extends AbstractStreamFilter { 40 | 41 | @Override 42 | public String getPurposeString() { 43 | return "Adds random noise to features in a stream."; 44 | } 45 | 46 | private static final long serialVersionUID = 1L; 47 | 48 | public IntOption randomSeedOption = new IntOption("randomSeed", 'r', "Seed for random noise.", 1); 49 | 50 | public FloatOption attNoiseFractionOption = new FloatOption("attNoise", 'f', "Probability of an attribute to be disturbed", 0.2, 0.0, 1.0); 51 | 52 | public FloatOption attValueNoiseFractionOption = new FloatOption("attValueNoise", 'a', "The fraction of attribute values to disturb.", 0.1, 0.0, 1.0); 53 | 54 | protected Random random; 55 | 56 | protected AutoExpandVector attValObservers; 57 | 58 | protected boolean[] disturbeAttribute; 59 | 60 | @Override 61 | protected void restartImpl() { 62 | this.disturbeAttribute = null; 63 | this.random = new Random(this.randomSeedOption.getValue()); 64 | this.attValObservers = new AutoExpandVector(); 65 | } 66 | 67 | @Override 68 | public InstancesHeader getHeader() { 69 | return this.inputStream.getHeader(); 70 | } 71 | 72 | public Instance filterInstance(Instance inst) { 73 | 74 | if(this.disturbeAttribute == null) { 75 | this.disturbeAttribute = new boolean[inst.numAttributes()-1]; 76 | 77 | int numberAttributesDisturbed = (int) 
Math.ceil((inst.numAttributes()-1) * this.attNoiseFractionOption.getValue()); 78 | 79 | ArrayList attributesPool = new ArrayList(); 80 | 81 | for(int i = 0; i < inst.numAttributes()-1; i++) { 82 | attributesPool.add(i); 83 | } 84 | 85 | for(int i = 0; i < numberAttributesDisturbed; i++) { 86 | this.disturbeAttribute[attributesPool.remove(this.random.nextInt(attributesPool.size()))] = true; 87 | } 88 | } 89 | 90 | for (int i = 0; i < inst.numAttributes() -1 ; i++) { 91 | if(this.disturbeAttribute[i]) { 92 | double noiseFrac = this.attValueNoiseFractionOption.getValue(); 93 | 94 | if (inst.attribute(i).isNominal()) { 95 | DoubleVector obs = (DoubleVector) this.attValObservers.get(i); 96 | if (obs == null) { 97 | obs = new DoubleVector(); 98 | this.attValObservers.set(i, obs); 99 | } 100 | int originalVal = (int) inst.value(i); 101 | if (!inst.isMissing(i)) { 102 | obs.addToValue(originalVal, inst.weight()); 103 | } 104 | if ((this.random.nextDouble() < noiseFrac) && (obs.numNonZeroEntries() > 1)) { 105 | do { 106 | inst.setValue(i, this.random.nextInt(obs.numValues())); 107 | } while (((int) inst.value(i) == originalVal) || (obs.getValue((int) inst.value(i)) == 0.0)); 108 | } 109 | } else { 110 | GaussianEstimator obs = (GaussianEstimator) this.attValObservers.get(i); 111 | if (obs == null) { 112 | obs = new GaussianEstimator(); 113 | this.attValObservers.set(i, obs); 114 | } 115 | obs.addObservation(inst.value(i), inst.weight()); 116 | inst.setValue(i, inst.value(i) + this.random.nextGaussian() * obs.getStdDev() * noiseFrac); 117 | } 118 | } 119 | } 120 | return inst; 121 | } 122 | 123 | @Override 124 | public void getDescription(StringBuilder sb, int indent) { 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /src/main/java/moa/streams/generators/imbalanced/AgrawalGenerator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * AgrawalGenerator.java 3 | * Copyright (C) 
2007 University of Waikato, Hamilton, New Zealand 4 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | * 19 | */ 20 | package moa.streams.generators.imbalanced; 21 | 22 | import com.yahoo.labs.samoa.instances.Attribute; 23 | import com.yahoo.labs.samoa.instances.DenseInstance; 24 | import moa.core.FastVector; 25 | import com.yahoo.labs.samoa.instances.Instance; 26 | import com.yahoo.labs.samoa.instances.Instances; 27 | 28 | import java.util.Random; 29 | import moa.core.Example; 30 | import moa.core.InstanceExample; 31 | 32 | import com.yahoo.labs.samoa.instances.InstancesHeader; 33 | import moa.core.ObjectRepository; 34 | import moa.options.AbstractOptionHandler; 35 | import com.github.javacliparser.FlagOption; 36 | import com.github.javacliparser.FloatOption; 37 | import com.github.javacliparser.IntOption; 38 | import moa.streams.ExampleStream; 39 | import moa.streams.InstanceStream; 40 | import moa.tasks.TaskMonitor; 41 | 42 | /** 43 | * Stream generator for Agrawal dataset. 44 | * Generator described in paper:
45 | * Rakesh Agrawal, Tomasz Imielinksi, and Arun Swami, 46 | * "Database Mining: A Performance Perspective", 47 | * IEEE Transactions on Knowledge and Data Engineering, 48 | * 5(6), December 1993.

49 | * 50 | * Public C source code available at:
51 | * 52 | * http://www.almaden.ibm.com/cs/projects/iis/hdb/Projects/data_mining/datasets/syndata.html

53 | * 54 | * Notes:
55 | * The built in functions are based on the paper (page 924), 56 | * which turn out to be functions pred20 thru pred29 in the public C implementation. 57 | * Perturbation function works like C implementation rather than description in paper. 58 | * 59 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) 60 | * @version $Revision: 7 $ 61 | */ 62 | public class AgrawalGenerator extends AbstractOptionHandler implements 63 | InstanceStream { 64 | 65 | @Override 66 | public String getPurposeString() { 67 | return "Generates one of ten different pre-defined loan functions."; 68 | } 69 | 70 | private static final long serialVersionUID = 1L; 71 | 72 | public IntOption functionOption = new IntOption("function", 'f', 73 | "Classification function used, as defined in the original paper.", 74 | 1, 1, 10); 75 | 76 | public IntOption instanceRandomSeedOption = new IntOption( 77 | "instanceRandomSeed", 'i', 78 | "Seed for random generation of instances.", 1); 79 | 80 | public FloatOption peturbFractionOption = new FloatOption("peturbFraction", 81 | 'p', 82 | "The amount of peturbation (noise) introduced to numeric values.", 83 | 0.05, 0.0, 1.0); 84 | 85 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm', 86 | "Percentage of minority class examples", 0.1, 0, 1); 87 | 88 | protected interface ClassFunction { 89 | 90 | public int determineClass(double salary, double commission, int age, 91 | int elevel, int car, int zipcode, double hvalue, int hyears, 92 | double loan); 93 | } 94 | 95 | protected static ClassFunction[] classificationFunctions = { 96 | // function 1 97 | new ClassFunction() { 98 | 99 | @Override 100 | public int determineClass(double salary, double commission, 101 | int age, int elevel, int car, int zipcode, 102 | double hvalue, int hyears, double loan) { 103 | return ((age < 40) || (60 <= age)) ? 
0 : 1; 104 | } 105 | }, 106 | // function 2 107 | new ClassFunction() { 108 | 109 | @Override 110 | public int determineClass(double salary, double commission, 111 | int age, int elevel, int car, int zipcode, 112 | double hvalue, int hyears, double loan) { 113 | if (age < 40) { 114 | return ((50000 <= salary) && (salary <= 100000)) ? 0 115 | : 1; 116 | } else if (age < 60) {// && age >= 40 117 | return ((75000 <= salary) && (salary <= 125000)) ? 0 118 | : 1; 119 | } else {// age >= 60 120 | return ((25000 <= salary) && (salary <= 75000)) ? 0 : 1; 121 | } 122 | } 123 | }, 124 | // function 3 125 | new ClassFunction() { 126 | 127 | @Override 128 | public int determineClass(double salary, double commission, 129 | int age, int elevel, int car, int zipcode, 130 | double hvalue, int hyears, double loan) { 131 | if (age < 40) { 132 | return ((elevel == 0) || (elevel == 1)) ? 0 : 1; 133 | } else if (age < 60) { // && age >= 40 134 | return ((elevel == 1) || (elevel == 2) || (elevel == 3)) ? 0 135 | : 1; 136 | } else { // age >= 60 137 | return ((elevel == 2) || (elevel == 3) || (elevel == 4)) ? 0 138 | : 1; 139 | } 140 | } 141 | }, 142 | // function 4 143 | new ClassFunction() { 144 | 145 | @Override 146 | public int determineClass(double salary, double commission, 147 | int age, int elevel, int car, int zipcode, 148 | double hvalue, int hyears, double loan) { 149 | if (age < 40) { 150 | if ((elevel == 0) || (elevel == 1)) { 151 | return ((25000 <= salary) && (salary <= 75000)) ? 0 152 | : 1; 153 | } 154 | return ((50000 <= salary) && (salary <= 100000)) ? 0 155 | : 1; 156 | } else if (age < 60) {// && age >= 40 157 | if ((elevel == 1) || (elevel == 2) || (elevel == 3)) { 158 | return ((50000 <= salary) && (salary <= 100000)) ? 0 159 | : 1; 160 | } 161 | return ((75000 <= salary) && (salary <= 125000)) ? 0 162 | : 1; 163 | } else {// age >= 60 164 | if ((elevel == 2) || (elevel == 3) || (elevel == 4)) { 165 | return ((50000 <= salary) && (salary <= 100000)) ? 
0 166 | : 1; 167 | } 168 | return ((25000 <= salary) && (salary <= 75000)) ? 0 : 1; 169 | } 170 | } 171 | }, 172 | // function 5 173 | new ClassFunction() { 174 | 175 | @Override 176 | public int determineClass(double salary, double commission, 177 | int age, int elevel, int car, int zipcode, 178 | double hvalue, int hyears, double loan) { 179 | if (age < 40) { 180 | if ((50000 <= salary) && (salary <= 100000)) { 181 | return ((100000 <= loan) && (loan <= 300000)) ? 0 182 | : 1; 183 | } 184 | return ((200000 <= loan) && (loan <= 400000)) ? 0 : 1; 185 | } else if (age < 60) {// && age >= 40 186 | if ((75000 <= salary) && (salary <= 125000)) { 187 | return ((200000 <= loan) && (loan <= 400000)) ? 0 188 | : 1; 189 | } 190 | return ((300000 <= loan) && (loan <= 500000)) ? 0 : 1; 191 | } else {// age >= 60 192 | if ((25000 <= salary) && (salary <= 75000)) { 193 | return ((300000 <= loan) && (loan <= 500000)) ? 0 194 | : 1; 195 | } 196 | return ((100000 <= loan) && (loan <= 300000)) ? 0 : 1; 197 | } 198 | } 199 | }, 200 | // function 6 201 | new ClassFunction() { 202 | 203 | @Override 204 | public int determineClass(double salary, double commission, 205 | int age, int elevel, int car, int zipcode, 206 | double hvalue, int hyears, double loan) { 207 | double totalSalary = salary + commission; 208 | if (age < 40) { 209 | return ((50000 <= totalSalary) && (totalSalary <= 100000)) ? 0 210 | : 1; 211 | } else if (age < 60) {// && age >= 40 212 | return ((75000 <= totalSalary) && (totalSalary <= 125000)) ? 0 213 | : 1; 214 | } else {// age >= 60 215 | return ((25000 <= totalSalary) && (totalSalary <= 75000)) ? 
0 216 | : 1; 217 | } 218 | } 219 | }, 220 | // function 7 221 | new ClassFunction() { 222 | 223 | @Override 224 | public int determineClass(double salary, double commission, 225 | int age, int elevel, int car, int zipcode, 226 | double hvalue, int hyears, double loan) { 227 | double disposable = (2.0 * (salary + commission) / 3.0 228 | - loan / 5.0 - 20000.0); 229 | return disposable > 0 ? 0 : 1; 230 | } 231 | }, 232 | // function 8 233 | new ClassFunction() { 234 | 235 | @Override 236 | public int determineClass(double salary, double commission, 237 | int age, int elevel, int car, int zipcode, 238 | double hvalue, int hyears, double loan) { 239 | double disposable = (2.0 * (salary + commission) / 3.0 240 | - 5000.0 * elevel - 20000.0); 241 | return disposable > 0 ? 0 : 1; 242 | } 243 | }, 244 | // function 9 245 | new ClassFunction() { 246 | 247 | @Override 248 | public int determineClass(double salary, double commission, 249 | int age, int elevel, int car, int zipcode, 250 | double hvalue, int hyears, double loan) { 251 | double disposable = (2.0 * (salary + commission) / 3.0 252 | - 5000.0 * elevel - loan / 5.0 - 10000.0); 253 | return disposable > 0 ? 0 : 1; 254 | } 255 | }, 256 | // function 10 257 | new ClassFunction() { 258 | 259 | @Override 260 | public int determineClass(double salary, double commission, 261 | int age, int elevel, int car, int zipcode, 262 | double hvalue, int hyears, double loan) { 263 | double equity = 0.0; 264 | if (hyears >= 20) { 265 | equity = hvalue * (hyears - 20.0) / 10.0; 266 | } 267 | double disposable = (2.0 * (salary + commission) / 3.0 268 | - 5000.0 * elevel + equity / 5.0 - 10000.0); 269 | return disposable > 0 ? 
0 : 1; 270 | } 271 | }}; 272 | 273 | protected InstancesHeader streamHeader; 274 | 275 | protected Random instanceRandom; 276 | 277 | protected boolean nextClassShouldBeZero; 278 | 279 | @Override 280 | protected void prepareForUseImpl(TaskMonitor monitor, 281 | ObjectRepository repository) { 282 | // generate header 283 | FastVector attributes = new FastVector(); 284 | attributes.addElement(new Attribute("salary")); 285 | attributes.addElement(new Attribute("commission")); 286 | attributes.addElement(new Attribute("age")); 287 | FastVector elevelLabels = new FastVector(); 288 | for (int i = 0; i < 5; i++) { 289 | elevelLabels.addElement("level" + i); 290 | } 291 | attributes.addElement(new Attribute("elevel", elevelLabels)); 292 | FastVector carLabels = new FastVector(); 293 | for (int i = 0; i < 20; i++) { 294 | carLabels.addElement("car" + (i + 1)); 295 | } 296 | attributes.addElement(new Attribute("car", carLabels)); 297 | FastVector zipCodeLabels = new FastVector(); 298 | for (int i = 0; i < 9; i++) { 299 | zipCodeLabels.addElement("zipcode" + (i + 1)); 300 | } 301 | attributes.addElement(new Attribute("zipcode", zipCodeLabels)); 302 | attributes.addElement(new Attribute("hvalue")); 303 | attributes.addElement(new Attribute("hyears")); 304 | attributes.addElement(new Attribute("loan")); 305 | FastVector classLabels = new FastVector(); 306 | classLabels.addElement("groupA"); 307 | classLabels.addElement("groupB"); 308 | attributes.addElement(new Attribute("class", classLabels)); 309 | this.streamHeader = new InstancesHeader(new Instances( 310 | getCLICreationString(InstanceStream.class), attributes, 0)); 311 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1); 312 | restart(); 313 | } 314 | 315 | @Override 316 | public long estimatedRemainingInstances() { 317 | return -1; 318 | } 319 | 320 | @Override 321 | public InstancesHeader getHeader() { 322 | return this.streamHeader; 323 | } 324 | 325 | @Override 326 | public boolean 
hasMoreInstances() { 327 | return true; 328 | } 329 | 330 | @Override 331 | public boolean isRestartable() { 332 | return true; 333 | } 334 | 335 | @Override 336 | public InstanceExample nextInstance() { 337 | double salary = 0, commission = 0, hvalue = 0, loan = 0; 338 | int age = 0, elevel = 0, car = 0, zipcode = 0, hyears = 0, group = 0; 339 | 340 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0; 341 | 342 | do { 343 | // generate attributes 344 | salary = 20000.0 + 130000.0 * this.instanceRandom.nextDouble(); 345 | commission = (salary >= 75000.0) ? 0 346 | : (10000.0 + 65000.0 * this.instanceRandom.nextDouble()); 347 | // true to c implementation: 348 | // if (instanceRandom.nextDouble() < 0.5 && salary < 75000.0) 349 | // commission = 10000.0 + 65000.0 * instanceRandom.nextDouble(); 350 | age = 20 + this.instanceRandom.nextInt(61); 351 | elevel = this.instanceRandom.nextInt(5); 352 | car = this.instanceRandom.nextInt(20); 353 | zipcode = this.instanceRandom.nextInt(9); 354 | hvalue = (9.0 - zipcode) * 100000.0 355 | * (0.5 + this.instanceRandom.nextDouble()); 356 | hyears = 1 + this.instanceRandom.nextInt(30); 357 | loan = this.instanceRandom.nextDouble() * 500000.0; 358 | // determine class 359 | group = classificationFunctions[this.functionOption.getValue() - 1].determineClass(salary, commission, age, elevel, car, zipcode, hvalue, hyears, loan); 360 | }while(group != label); 361 | 362 | // perturb values 363 | if (this.peturbFractionOption.getValue() > 0.0) { 364 | salary = perturbValue(salary, 20000, 150000); 365 | if (commission > 0) { 366 | commission = perturbValue(commission, 10000, 75000); 367 | } 368 | age = (int) Math.round(perturbValue(age, 20, 80)); 369 | hvalue = perturbValue(hvalue, (9.0 - zipcode) * 100000.0, 0, 135000); 370 | hyears = (int) Math.round(perturbValue(hyears, 1, 30)); 371 | loan = perturbValue(loan, 0, 500000); 372 | } 373 | // construct instance 374 | InstancesHeader header = getHeader(); 375 | 
Instance inst = new DenseInstance(header.numAttributes()); 376 | inst.setValue(0, salary); 377 | inst.setValue(1, commission); 378 | inst.setValue(2, age); 379 | inst.setValue(3, elevel); 380 | inst.setValue(4, car); 381 | inst.setValue(5, zipcode); 382 | inst.setValue(6, hvalue); 383 | inst.setValue(7, hyears); 384 | inst.setValue(8, loan); 385 | inst.setDataset(header); 386 | inst.setClassValue(group); 387 | return new InstanceExample(inst); 388 | } 389 | 390 | protected double perturbValue(double val, double min, double max) { 391 | return perturbValue(val, max - min, min, max); 392 | } 393 | 394 | protected double perturbValue(double val, double range, double min, 395 | double max) { 396 | val += range * (2.0 * (this.instanceRandom.nextDouble() - 0.5)) 397 | * this.peturbFractionOption.getValue(); 398 | if (val < min) { 399 | val = min; 400 | } else if (val > max) { 401 | val = max; 402 | } 403 | return val; 404 | } 405 | 406 | @Override 407 | public void restart() { 408 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue()); 409 | this.nextClassShouldBeZero = false; 410 | } 411 | 412 | @Override 413 | public void getDescription(StringBuilder sb, int indent) { 414 | // TODO Auto-generated method stub 415 | } 416 | } 417 | -------------------------------------------------------------------------------- /src/main/java/moa/streams/generators/imbalanced/AssetNegotiationGenerator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * AssetNegotiationGenerator.java 3 | * Copyright (C) 2016 Pontifícia Universidade Católica do Paraná, Curitiba, Brazil 4 | * @author Jean Paul Barddal (jean.barddal@ppgia.pucpr.br) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | * 19 | */ 20 | 21 | package moa.streams.generators.imbalanced; 22 | 23 | import com.github.javacliparser.FloatOption; 24 | import com.github.javacliparser.IntOption; 25 | import com.yahoo.labs.samoa.instances.Attribute; 26 | import com.yahoo.labs.samoa.instances.DenseInstance; 27 | import com.yahoo.labs.samoa.instances.Instance; 28 | import com.yahoo.labs.samoa.instances.Instances; 29 | import com.yahoo.labs.samoa.instances.InstancesHeader; 30 | import java.util.Arrays; 31 | import java.util.Random; 32 | import moa.core.FastVector; 33 | import moa.core.InstanceExample; 34 | import moa.core.ObjectRepository; 35 | import moa.options.AbstractOptionHandler; 36 | import moa.streams.InstanceStream; 37 | import moa.tasks.TaskMonitor; 38 | 39 | /** 40 | * 41 | * @author Jean Paul Barddal 42 | * @author Fabrício Enembreck 43 | * 44 | * @version 1.0 45 | * @see Originally discussed in F. Enembreck, B. C. Ávila, E. E. Scalabrin & 46 | * J-P. Barthès. LEARNING DRIFTING NEGOTIATIONS. In Applied Artificial 47 | * Intelligence: An International Journal. Volume 21, Issue 9, 2007. DOI: 48 | * 10.1080/08839510701526954 49 | * @see First used in the data stream configuration in J. P. Barddal, H. M. 50 | * Gomes, F. Enembreck, B. Pfahringer & A. Bifet. ON DYNAMIC FEATURE WEIGHTING 51 | * FOR FEATURE DRIFTING DATA STREAMS. In European Conference on Machine Learning 52 | * and Principles and Practice of Knowledge Discovery (ECML/PKDD'16). 2016. 
53 | */ 54 | 55 | public class AssetNegotiationGenerator 56 | extends AbstractOptionHandler 57 | implements InstanceStream { 58 | 59 | /* 60 | * OPTIONS 61 | */ 62 | public IntOption functionOption = new IntOption("function", 'f', 63 | "Classification function used, as defined in the original paper.", 64 | 1, 1, 5); 65 | public FloatOption noisePercentage = new FloatOption("noise", 'n', 66 | "% of class noise.", 0.05, 0.0, 1.0f); 67 | 68 | public IntOption instanceRandomSeedOption = new IntOption( 69 | "instanceRandomSeed", 'i', 70 | "Seed for random generation of instances.", 1); 71 | 72 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm', 73 | "Percentage of minority class examples", 0.1, 0, 1); 74 | 75 | /* 76 | * INTERNALS 77 | */ 78 | protected InstancesHeader streamHeader; 79 | 80 | protected Random instanceRandom; 81 | 82 | protected boolean nextClassShouldBeZero; 83 | 84 | protected ClassFunction classFunction; 85 | 86 | 87 | /* 88 | * FEATURE DEFINITIONS 89 | */ 90 | protected static String colorValues[] = {"black", 91 | "blue", 92 | "cyan", 93 | "brown", 94 | "red", 95 | "green", 96 | "yellow", 97 | "magenta"}; 98 | 99 | protected static String priceValues[] = {"veryLow", 100 | "low", 101 | "normal", 102 | "high", 103 | "veryHigh", 104 | "quiteHigh", 105 | "enormous", 106 | "non_salable"}; 107 | 108 | protected static String paymentValues[] = {"0", 109 | "30", 110 | "60", 111 | "90", 112 | "120", 113 | "150", 114 | "180", 115 | "210", 116 | "240"}; 117 | 118 | protected static String amountValues[] = {"veryLow", 119 | "low", 120 | "normal", 121 | "high", 122 | "veryHigh", 123 | "quiteHigh", 124 | "enormous", 125 | "non_ensured"}; 126 | 127 | protected static String deliveryDelayValues[] = {"veryLow", 128 | "low", 129 | "normal", 130 | "high", 131 | "veryHigh"}; 132 | 133 | protected static String classValues[] = {"interested", "notInterested"}; 134 | 135 | /* 136 | * Labeling functions 137 | */ 138 | protected interface 
ClassFunction { 139 | 140 | public int determineClass(String color, 141 | String price, 142 | String payment, 143 | String amount, 144 | String deliveryDelay); 145 | 146 | public Instance makeTrue(Instance intnc); 147 | } 148 | 149 | protected static ClassFunction concepts[] = { 150 | new ClassFunction() { 151 | Random r = new Random(Integer.MAX_VALUE); 152 | 153 | @Override 154 | public int determineClass(String color, 155 | String price, 156 | String payment, 157 | String amount, 158 | String deliveryDelay) { 159 | if ((price.equals("normal") && amount.equals("high") 160 | || (color.equals("brown") && price.equals("veryLow") 161 | && deliveryDelay.equals("high")))) { 162 | return indexOfValue("interested", classValues); 163 | } 164 | return indexOfValue("notInterested", classValues); 165 | } 166 | 167 | @Override 168 | public Instance makeTrue(Instance intnc) { 169 | int part = r.nextInt(2); 170 | if (part == 0) { 171 | 172 | intnc.setValue(1, indexOfValue("normal", priceValues)); 173 | intnc.setValue(3, indexOfValue("high", amountValues)); 174 | } else { 175 | intnc.setValue(0, indexOfValue("brown", colorValues)); 176 | intnc.setValue(1, indexOfValue("veryLow", priceValues)); 177 | intnc.setValue(4, indexOfValue("high", deliveryDelayValues)); 178 | } 179 | intnc.setClassValue(indexOfValue("interested", classValues)); 180 | return intnc; 181 | } 182 | 183 | }, 184 | new ClassFunction() { 185 | Random r = new Random(Integer.MAX_VALUE); 186 | 187 | @Override 188 | public int determineClass(String color, 189 | String price, 190 | String payment, 191 | String amount, 192 | String deliveryDelay) { 193 | if (price.equals("high") && amount.equals("veryHigh") 194 | && deliveryDelay.equals("high")) { 195 | return indexOfValue("interested", classValues); 196 | } 197 | return indexOfValue("notInterested", classValues); 198 | } 199 | 200 | @Override 201 | public Instance makeTrue(Instance intnc) { 202 | intnc.setValue(1, indexOfValue("high", priceValues)); 203 | 
intnc.setValue(3, indexOfValue("veryHigh", amountValues)); 204 | intnc.setValue(4, indexOfValue("high", deliveryDelayValues)); 205 | intnc.setClassValue(Arrays.asList(classValues).indexOf("interested")); 206 | return intnc; 207 | } 208 | }, 209 | new ClassFunction() { 210 | Random r = new Random(Integer.MAX_VALUE); 211 | 212 | @Override 213 | public int determineClass(String color, 214 | String price, 215 | String payment, 216 | String amount, 217 | String deliveryDelay) { 218 | if ((price.equals("veryLow") 219 | && payment.equals("0") && amount.equals("high")) 220 | || (color.equals("red") && price.equals("low") 221 | && payment.equals("30"))) { 222 | return indexOfValue("interested", classValues); 223 | } 224 | return indexOfValue("notInterested", classValues); 225 | } 226 | 227 | @Override 228 | public Instance makeTrue(Instance intnc) { 229 | int part = r.nextInt(2); 230 | if (part == 0) { 231 | intnc.setValue(1, indexOfValue("veryLow", priceValues)); 232 | intnc.setValue(2, indexOfValue("0", paymentValues)); 233 | intnc.setValue(3, indexOfValue("high", amountValues)); 234 | } else { 235 | intnc.setValue(0, indexOfValue("red", colorValues)); 236 | intnc.setValue(1, indexOfValue("low", priceValues)); 237 | intnc.setValue(2, indexOfValue("30", paymentValues)); 238 | } 239 | intnc.setClassValue(Arrays.asList(classValues).indexOf("interested")); 240 | return intnc; 241 | } 242 | }, 243 | new ClassFunction() { 244 | Random r = new Random(Integer.MAX_VALUE); 245 | 246 | @Override 247 | public int determineClass(String color, 248 | String price, 249 | String payment, 250 | String amount, 251 | String deliveryDelay) { 252 | if ((color.equals("black") 253 | && payment.equals("90") 254 | && deliveryDelay.equals("veryLow")) 255 | || (color.equals("magenta") 256 | && price.equals("high") 257 | && deliveryDelay.equals("veryLow"))) { 258 | return indexOfValue("interested", classValues); 259 | } 260 | return indexOfValue("notInterested", classValues); 261 | } 262 | 263 | 
@Override 264 | public Instance makeTrue(Instance intnc) { 265 | int part = r.nextInt(2); 266 | if (part == 0) { 267 | intnc.setValue(0, indexOfValue("black", colorValues)); 268 | intnc.setValue(2, indexOfValue("90", paymentValues)); 269 | intnc.setValue(4, indexOfValue("veryLow", deliveryDelayValues)); 270 | } else { 271 | intnc.setValue(0, indexOfValue("magenta", colorValues)); 272 | intnc.setValue(1, indexOfValue("high", priceValues)); 273 | intnc.setValue(4, indexOfValue("veryLow", deliveryDelayValues)); 274 | } 275 | intnc.setClassValue(Arrays.asList(classValues).indexOf("interested")); 276 | return intnc; 277 | } 278 | }, 279 | new ClassFunction() { 280 | Random r = new Random(Integer.MAX_VALUE); 281 | 282 | @Override 283 | public int determineClass(String color, 284 | String price, 285 | String payment, 286 | String amount, 287 | String deliveryDelay) { 288 | if ((color.equals("blue") 289 | && payment.equals("60") 290 | && amount.equals("low") 291 | && deliveryDelay.equals("normal")) 292 | || (color.equals("cyan") 293 | && amount.equals("low") 294 | && deliveryDelay.equals("normal"))) { 295 | return indexOfValue("interested", classValues); 296 | } 297 | return indexOfValue("notInterested", classValues); 298 | } 299 | 300 | @Override 301 | public Instance makeTrue(Instance intnc) { 302 | int part = r.nextInt(2); 303 | if (part == 0) { 304 | intnc.setValue(0, indexOfValue("blue", colorValues)); 305 | intnc.setValue(2, indexOfValue("60", paymentValues)); 306 | intnc.setValue(3, indexOfValue("low", amountValues)); 307 | intnc.setValue(4, indexOfValue("normal", deliveryDelayValues)); 308 | } else { 309 | intnc.setValue(0, indexOfValue("cyan", colorValues)); 310 | intnc.setValue(3, indexOfValue("low", amountValues)); 311 | intnc.setValue(4, indexOfValue("normal", deliveryDelayValues)); 312 | } 313 | intnc.setClassValue(Arrays.asList(classValues).indexOf("interested")); 314 | return intnc; 315 | } 316 | } 317 | }; 318 | 319 | /* 320 | * Generator core 321 | */ 322 
| 323 | @Override 324 | public void getDescription(StringBuilder sb, int indent) { 325 | sb.append("Generates instances based on 5 different concept functions " 326 | + "that describe whether another agent is " 327 | + "interested or not in an item."); 328 | } 329 | 330 | @Override 331 | protected void prepareForUseImpl(TaskMonitor tm, ObjectRepository or) { 332 | 333 | classFunction = concepts[this.functionOption.getValue() - 1]; 334 | 335 | FastVector attributes = new FastVector(); 336 | attributes.addElement(new Attribute("color", 337 | Arrays.asList(colorValues))); 338 | attributes.addElement(new Attribute("price", 339 | Arrays.asList(priceValues))); 340 | attributes.addElement(new Attribute("payment", 341 | Arrays.asList(paymentValues))); 342 | attributes.addElement(new Attribute("amount", 343 | Arrays.asList(amountValues))); 344 | attributes.addElement(new Attribute("deliveryDelay", 345 | Arrays.asList(deliveryDelayValues))); 346 | 347 | this.instanceRandom = new Random(System.currentTimeMillis()); 348 | 349 | FastVector classLabels = new FastVector(); 350 | for (int i = 0; i < classValues.length; i++) { 351 | classLabels.addElement(classValues[i]); 352 | } 353 | 354 | attributes.addElement(new Attribute("class", classLabels)); 355 | this.streamHeader = new InstancesHeader(new Instances( 356 | getCLICreationString(InstanceStream.class), attributes, 0)); 357 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1); 358 | 359 | restart(); 360 | } 361 | 362 | @Override 363 | public InstancesHeader getHeader() { 364 | return streamHeader; 365 | } 366 | 367 | @Override 368 | public long estimatedRemainingInstances() { 369 | return Integer.MAX_VALUE; 370 | } 371 | 372 | @Override 373 | public boolean hasMoreInstances() { 374 | return true; 375 | } 376 | 377 | @Override 378 | public InstanceExample nextInstance() { 379 | Instance instnc = null; 380 | 381 | int classValue = -1; 382 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() 
? 1 : 0; 383 | 384 | do{ 385 | //randomize indexes for new instance 386 | int indexColor = this.instanceRandom.nextInt(colorValues.length); 387 | int indexPrice = this.instanceRandom.nextInt(priceValues.length); 388 | int indexPayment = this.instanceRandom.nextInt(paymentValues.length); 389 | int indexAmount = this.instanceRandom.nextInt(amountValues.length); 390 | int indexDelivery = this.instanceRandom.nextInt(deliveryDelayValues.length); 391 | //retrieve values 392 | String color = colorValues[indexColor]; 393 | String price = priceValues[indexPrice]; 394 | String payment = paymentValues[indexPayment]; 395 | String amount = amountValues[indexAmount]; 396 | String delivery = deliveryDelayValues[indexDelivery]; 397 | classValue = classFunction.determineClass(color, price, payment, amount, delivery); 398 | 399 | instnc = new DenseInstance(streamHeader.numAttributes()); 400 | //set values 401 | instnc.setDataset(this.getHeader()); 402 | instnc.setValue(0, Arrays.asList(colorValues).indexOf(color)); 403 | instnc.setValue(1, Arrays.asList(priceValues).indexOf(price)); 404 | instnc.setValue(2, Arrays.asList(paymentValues).indexOf(payment)); 405 | instnc.setValue(3, Arrays.asList(amountValues).indexOf(amount)); 406 | instnc.setValue(4, Arrays.asList(deliveryDelayValues).indexOf(delivery)); 407 | 408 | instnc.setClassValue((int) classValue); 409 | 410 | }while(classValue != label); 411 | 412 | //add noise 413 | int newClassValue = addNoise((int) instnc.classValue()); 414 | instnc.setClassValue(newClassValue); 415 | return new InstanceExample(instnc); 416 | } 417 | 418 | @Override 419 | public boolean isRestartable() { 420 | return true; 421 | } 422 | 423 | @Override 424 | public void restart() { 425 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue()); 426 | this.nextClassShouldBeZero = false; 427 | } 428 | 429 | int addNoise(int classObtained) { 430 | if (this.instanceRandom.nextFloat() <= this.noisePercentage.getValue()) { 431 | classObtained = 
classObtained == 0 ? 1 : 0; 432 | } 433 | return classObtained; 434 | } 435 | 436 | private static int indexOfValue(String value, Object[] arr) { 437 | int index = Arrays.asList(arr).indexOf(value); 438 | return index; 439 | } 440 | } -------------------------------------------------------------------------------- /src/main/java/moa/streams/generators/imbalanced/HyperplaneGenerator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * HyperplaneGenerator.java 3 | * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand 4 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | * 19 | */ 20 | package moa.streams.generators.imbalanced; 21 | 22 | import java.util.Arrays; 23 | import java.util.Random; 24 | import com.github.javacliparser.FloatOption; 25 | import com.github.javacliparser.IntOption; 26 | import moa.core.FastVector; 27 | import moa.core.InstanceExample; 28 | import moa.core.ObjectRepository; 29 | import moa.options.AbstractOptionHandler; 30 | import moa.streams.InstanceStream; 31 | import moa.tasks.TaskMonitor; 32 | import com.yahoo.labs.samoa.instances.Attribute; 33 | import com.yahoo.labs.samoa.instances.DenseInstance; 34 | import com.yahoo.labs.samoa.instances.Instance; 35 | import com.yahoo.labs.samoa.instances.Instances; 36 | import com.yahoo.labs.samoa.instances.InstancesHeader; 37 | 38 | /** 39 | * Stream generator for Hyperplane data stream. 40 | * 41 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 42 | * @version $Revision: 7 $ 43 | */ 44 | public class HyperplaneGenerator extends AbstractOptionHandler implements 45 | InstanceStream { 46 | 47 | @Override 48 | public String getPurposeString() { 49 | return "Generates a problem of predicting class of a rotating hyperplane."; 50 | } 51 | 52 | private static final long serialVersionUID = 1L; 53 | 54 | public IntOption instanceRandomSeedOption = new IntOption( 55 | "instanceRandomSeed", 'i', 56 | "Seed for random generation of instances.", 1); 57 | 58 | public IntOption numClassesOption = new IntOption("numClasses", 'c', 59 | "The number of classes to generate.", 2, 2, Integer.MAX_VALUE); 60 | 61 | public IntOption numAttsOption = new IntOption("numAtts", 'a', 62 | "The number of attributes to generate.", 10, 0, Integer.MAX_VALUE); 63 | 64 | public IntOption numDriftAttsOption = new IntOption("numDriftAtts", 'k', 65 | "The number of attributes with drift.", 2, 0, Integer.MAX_VALUE); 66 | 67 | public FloatOption magChangeOption = new FloatOption("magChange", 't', 68 | "Magnitude of the change for every example", 0.0, 0.0, 1.0); 69 | 70 | public 
IntOption noisePercentageOption = new IntOption("noisePercentage", 71 | 'n', "Percentage of noise to add to the data.", 5, 0, 100); 72 | 73 | public IntOption sigmaPercentageOption = new IntOption("sigmaPercentage", 74 | 's', "Percentage of probability that the direction of change is reversed.", 10, 0, 100); 75 | 76 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm', 77 | "Percentage of minority class examples", 0.1, 0, 1); 78 | 79 | protected InstancesHeader streamHeader; 80 | 81 | protected Random instanceRandom; 82 | 83 | protected double[] weights; 84 | 85 | protected int[] sigma; 86 | 87 | public int numberInstance; 88 | 89 | @Override 90 | protected void prepareForUseImpl(TaskMonitor monitor, 91 | ObjectRepository repository) { 92 | monitor.setCurrentActivity("Preparing hyperplane...", -1.0); 93 | generateHeader(); 94 | restart(); 95 | } 96 | 97 | protected void generateHeader() { 98 | FastVector attributes = new FastVector(); 99 | for (int i = 0; i < this.numAttsOption.getValue(); i++) { 100 | attributes.addElement(new Attribute("att" + (i + 1))); 101 | } 102 | 103 | FastVector classLabels = new FastVector(); 104 | for (int i = 0; i < this.numClassesOption.getValue(); i++) { 105 | classLabels.addElement("class" + (i + 1)); 106 | } 107 | attributes.addElement(new Attribute("class", classLabels)); 108 | this.streamHeader = new InstancesHeader(new Instances( 109 | getCLICreationString(InstanceStream.class), attributes, 0)); 110 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1); 111 | } 112 | 113 | @Override 114 | public long estimatedRemainingInstances() { 115 | return -1; 116 | } 117 | 118 | @Override 119 | public InstancesHeader getHeader() { 120 | return this.streamHeader; 121 | } 122 | 123 | @Override 124 | public boolean hasMoreInstances() { 125 | return true; 126 | } 127 | 128 | @Override 129 | public boolean isRestartable() { 130 | return true; 131 | } 132 | 133 | @Override 134 | public InstanceExample 
nextInstance() { 135 | 136 | int numAtts = this.numAttsOption.getValue(); 137 | double[] attVals = new double[numAtts + 1]; 138 | double sum = 0.0; 139 | double sumWeights = 0.0; 140 | int classLabel = -1; 141 | 142 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0; 143 | 144 | do 145 | { 146 | sum = 0.0; 147 | sumWeights = 0.0; 148 | 149 | for (int i = 0; i < numAtts; i++) { 150 | attVals[i] = this.instanceRandom.nextDouble(); 151 | sum += this.weights[i] * attVals[i]; 152 | sumWeights += this.weights[i]; 153 | } 154 | 155 | if (sum >= sumWeights * 0.5) { 156 | classLabel = 1; 157 | } else { 158 | classLabel = 0; 159 | } 160 | }while(classLabel != label); 161 | 162 | //Add Noise 163 | if ((1 + (this.instanceRandom.nextInt(100))) <= this.noisePercentageOption.getValue()) { 164 | classLabel = (classLabel == 0 ? 1 : 0); 165 | } 166 | 167 | Instance inst = new DenseInstance(1.0, attVals); 168 | inst.setDataset(getHeader()); 169 | inst.setClassValue(classLabel); 170 | addDrift(); 171 | return new InstanceExample(inst); 172 | } 173 | 174 | private void addDrift() { 175 | for (int i = 0; i < this.numDriftAttsOption.getValue(); i++) { 176 | this.weights[i] += (double) ((double) sigma[i]) * ((double) this.magChangeOption.getValue()); 177 | if (//this.weights[i] >= 1.0 || this.weights[i] <= 0.0 || 178 | (1 + (this.instanceRandom.nextInt(100))) <= this.sigmaPercentageOption.getValue()) { 179 | this.sigma[i] *= -1; 180 | } 181 | } 182 | } 183 | 184 | @Override 185 | public void restart() { 186 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue()); 187 | this.weights = new double[this.numAttsOption.getValue()]; 188 | this.sigma = new int[this.numAttsOption.getValue()]; 189 | for (int i = 0; i < this.numAttsOption.getValue(); i++) { 190 | this.weights[i] = this.instanceRandom.nextDouble(); 191 | this.sigma[i] = (i < this.numDriftAttsOption.getValue() ? 
1 : 0); 192 | } 193 | } 194 | 195 | @Override 196 | public void getDescription(StringBuilder sb, int indent) { 197 | // TODO Auto-generated method stub 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /src/main/java/moa/streams/generators/imbalanced/MixedGenerator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * MixedGenerator.java 3 | * Copyright (C) 2016 Instituto Federal de Pernambuco 4 | * @author Paulo Gonçalves (paulogoncalves@recife.ifpe.edu.br) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | * 19 | */ 20 | package moa.streams.generators.imbalanced; 21 | 22 | import com.github.javacliparser.FlagOption; 23 | import com.github.javacliparser.FloatOption; 24 | import com.github.javacliparser.IntOption; 25 | import com.yahoo.labs.samoa.instances.Attribute; 26 | import com.yahoo.labs.samoa.instances.DenseInstance; 27 | import com.yahoo.labs.samoa.instances.Instance; 28 | import com.yahoo.labs.samoa.instances.Instances; 29 | import com.yahoo.labs.samoa.instances.InstancesHeader; 30 | import java.util.ArrayList; 31 | import java.util.List; 32 | import java.util.Random; 33 | import moa.core.InstanceExample; 34 | 35 | import moa.core.ObjectRepository; 36 | import moa.options.AbstractOptionHandler; 37 | import moa.streams.InstanceStream; 38 | import moa.tasks.TaskMonitor; 39 | 40 | /** 41 | * Abrupt concept drift, boolean noise-free examples. Four relevant attributes, 42 | * two boolean attributes v,w and two numeric attributes from [0; 1]. The 43 | * examples are classified positive if two of three conditions are satisfied: 44 | * v,w, y < 0,5 + 0,3 sin(3 * PI * x). After each context change the 45 | * classification is reversed. Proposed by "Gama, Joao, et al. "Learning with 46 | * drift detection." Advances in artificial intelligence–SBIA 2004. Springer 47 | * Berlin Heidelberg, 2004. 286-295." 
48 | * 49 | * @author Paulo Gonçalves (paulogoncalves@recife.ifpe.edu.br) 50 | * @version $Revision: 1 $ 51 | */ 52 | public class MixedGenerator extends AbstractOptionHandler implements 53 | InstanceStream { 54 | 55 | public IntOption functionOption = new IntOption("function", 'f', 56 | "Classification function used, as defined in the original paper.", 57 | 1, 1, 2); 58 | 59 | public IntOption instanceRandomSeedOption = new IntOption( 60 | "instanceRandomSeed", 'i', 61 | "Seed for random generation of instances.", 1); 62 | 63 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm', 64 | "Percentage of minority class examples", 0.1, 0, 1); 65 | 66 | protected InstancesHeader streamHeader; 67 | 68 | protected Random instanceRandom; 69 | 70 | protected boolean nextClassShouldBeZero; 71 | 72 | protected interface ClassFunction { 73 | 74 | public int determineClass(double v, double w, double x, double y); 75 | } 76 | 77 | protected static ClassFunction[] classificationFunctions = { 78 | new ClassFunction() { 79 | @Override 80 | public int determineClass(double v, double w, double x, double y) { 81 | boolean z = y < 0.5 + 0.3 * Math.sin(3 * Math.PI * x); 82 | if ((v == 1 && w == 1) || (v == 1 && z) || (w == 1 && z)) { 83 | return 0; 84 | } else { 85 | return 1; 86 | } 87 | } 88 | }, 89 | new ClassFunction() { 90 | @Override 91 | public int determineClass(double v, double w, double x, double y) { 92 | boolean z = y < 0.5 + 0.3 * Math.sin(3 * Math.PI * x); 93 | if ((v == 1 && w == 1) || (v == 1 && z) || (w == 1 && z)) { 94 | return 1; 95 | } else { 96 | return 0; 97 | } 98 | } 99 | },}; 100 | 101 | @Override 102 | public void getDescription(StringBuilder sb, int indent) { 103 | 104 | } 105 | 106 | @Override 107 | public InstancesHeader getHeader() { 108 | return this.streamHeader; 109 | } 110 | 111 | @Override 112 | public long estimatedRemainingInstances() { 113 | return -1; 114 | } 115 | 116 | @Override 117 | public boolean hasMoreInstances() { 118 | 
return true; 119 | } 120 | 121 | @Override 122 | public InstanceExample nextInstance() { 123 | double v = 0, w = 0, x = 0, y = 0, group = 0; 124 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0; 125 | 126 | do { 127 | v = (this.instanceRandom.nextDouble() < 0.5) ? 0 : 1; 128 | w = (this.instanceRandom.nextDouble() < 0.5) ? 0 : 1; 129 | x = this.instanceRandom.nextDouble(); 130 | y = this.instanceRandom.nextDouble(); 131 | group = classificationFunctions[this.functionOption.getValue() - 1].determineClass(v, w, x, y); 132 | }while(group != label); 133 | 134 | // construct instance 135 | InstancesHeader header = getHeader(); 136 | Instance inst = new DenseInstance(header.numAttributes()); 137 | inst.setValue(0, v); 138 | inst.setValue(1, w); 139 | inst.setValue(2, x); 140 | inst.setValue(3, y); 141 | inst.setDataset(header); 142 | inst.setClassValue(group); 143 | return new InstanceExample(inst); 144 | } 145 | 146 | @Override 147 | public boolean isRestartable() { 148 | return true; 149 | } 150 | 151 | @Override 152 | public void restart() { 153 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue()); 154 | this.nextClassShouldBeZero = false; 155 | } 156 | 157 | @Override 158 | protected void prepareForUseImpl(TaskMonitor monitor, 159 | ObjectRepository repository) { 160 | List booleanLabels = new ArrayList(); 161 | booleanLabels.add("0"); 162 | booleanLabels.add("1"); 163 | 164 | ArrayList attributes = new ArrayList(); 165 | Attribute attribute1 = new Attribute("v", booleanLabels); 166 | Attribute attribute2 = new Attribute("w", booleanLabels); 167 | 168 | Attribute attribute3 = new Attribute("x"); 169 | Attribute attribute4 = new Attribute("y"); 170 | 171 | List classLabels = new ArrayList(); 172 | classLabels.add("positive"); 173 | classLabels.add("negative"); 174 | Attribute classAtt = new Attribute("class", classLabels); 175 | 176 | attributes.add(attribute1); 177 | attributes.add(attribute2); 178 | 
attributes.add(attribute3); 179 | attributes.add(attribute4); 180 | attributes.add(classAtt); 181 | 182 | this.streamHeader = new InstancesHeader(new Instances( 183 | getCLICreationString(InstanceStream.class), attributes, 0)); 184 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1); 185 | restart(); 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /src/main/java/moa/streams/generators/imbalanced/RandomRBFGenerator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * RandomRBFGenerator.java 3 | * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand 4 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | * 19 | */ 20 | package moa.streams.generators.imbalanced; 21 | 22 | import com.yahoo.labs.samoa.instances.Attribute; 23 | import com.yahoo.labs.samoa.instances.DenseInstance; 24 | import moa.core.FastVector; 25 | import com.yahoo.labs.samoa.instances.Instance; 26 | import com.yahoo.labs.samoa.instances.Instances; 27 | 28 | import java.io.Serializable; 29 | import java.util.Random; 30 | import moa.core.InstanceExample; 31 | 32 | import com.yahoo.labs.samoa.instances.InstancesHeader; 33 | import moa.core.MiscUtils; 34 | import moa.core.ObjectRepository; 35 | import moa.options.AbstractOptionHandler; 36 | 37 | import com.github.javacliparser.FloatOption; 38 | import com.github.javacliparser.IntOption; 39 | import moa.streams.InstanceStream; 40 | import moa.tasks.TaskMonitor; 41 | 42 | /** 43 | * Stream generator for a random radial basis function stream. 44 | * 45 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) 46 | * @version $Revision: 7 $ 47 | */ 48 | public class RandomRBFGenerator extends AbstractOptionHandler implements 49 | InstanceStream { 50 | 51 | @Override 52 | public String getPurposeString() { 53 | return "Generates a random radial basis function stream."; 54 | } 55 | 56 | private static final long serialVersionUID = 1L; 57 | 58 | public IntOption modelRandomSeedOption = new IntOption("modelRandomSeed", 59 | 'r', "Seed for random generation of model.", 1); 60 | 61 | public IntOption instanceRandomSeedOption = new IntOption( 62 | "instanceRandomSeed", 'i', 63 | "Seed for random generation of instances.", 1); 64 | 65 | public IntOption numClassesOption = new IntOption("numClasses", 'c', 66 | "The number of classes to generate.", 2, 2, Integer.MAX_VALUE); 67 | 68 | public IntOption numAttsOption = new IntOption("numAtts", 'a', 69 | "The number of attributes to generate.", 10, 0, Integer.MAX_VALUE); 70 | 71 | public IntOption numCentroidsOption = new IntOption("numCentroids", 'n', 72 | "The number of centroids in the model.", 50, 1, 
Integer.MAX_VALUE); 73 | 74 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm', 75 | "Percentage of minority class examples", 0.1, 0, 1); 76 | 77 | protected static class Centroid implements Serializable { 78 | 79 | private static final long serialVersionUID = 1L; 80 | 81 | public double[] centre; 82 | 83 | public int classLabel; 84 | 85 | public double stdDev; 86 | } 87 | 88 | protected InstancesHeader streamHeader; 89 | 90 | protected Centroid[] centroids; 91 | 92 | protected double[] centroidWeights; 93 | 94 | protected Random instanceRandom; 95 | 96 | @Override 97 | public void prepareForUseImpl(TaskMonitor monitor, 98 | ObjectRepository repository) { 99 | monitor.setCurrentActivity("Preparing random RBF...", -1.0); 100 | generateHeader(); 101 | generateCentroids(); 102 | restart(); 103 | } 104 | 105 | @Override 106 | public InstancesHeader getHeader() { 107 | return this.streamHeader; 108 | } 109 | 110 | @Override 111 | public long estimatedRemainingInstances() { 112 | return -1; 113 | } 114 | 115 | @Override 116 | public boolean hasMoreInstances() { 117 | return true; 118 | } 119 | 120 | @Override 121 | public boolean isRestartable() { 122 | return true; 123 | } 124 | 125 | @Override 126 | public void restart() { 127 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue()); 128 | } 129 | 130 | @Override 131 | public InstanceExample nextInstance() { 132 | Centroid centroid; 133 | 134 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 
1 : 0; 135 | 136 | do{ 137 | centroid = this.centroids[MiscUtils.chooseRandomIndexBasedOnWeights(this.centroidWeights, this.instanceRandom)]; 138 | }while(centroid.classLabel != label); 139 | 140 | int numAtts = this.numAttsOption.getValue(); 141 | double[] attVals = new double[numAtts + 1]; 142 | for (int i = 0; i < numAtts; i++) { 143 | attVals[i] = (this.instanceRandom.nextDouble() * 2.0) - 1.0; 144 | } 145 | double magnitude = 0.0; 146 | for (int i = 0; i < numAtts; i++) { 147 | magnitude += attVals[i] * attVals[i]; 148 | } 149 | magnitude = Math.sqrt(magnitude); 150 | double desiredMag = this.instanceRandom.nextGaussian() * centroid.stdDev; 151 | double scale = desiredMag / magnitude; 152 | for (int i = 0; i < numAtts; i++) { 153 | attVals[i] = centroid.centre[i] + attVals[i] * scale; 154 | } 155 | Instance inst = new DenseInstance(1.0, attVals); 156 | inst.setDataset(getHeader()); 157 | inst.setClassValue(centroid.classLabel); 158 | return new InstanceExample(inst); 159 | } 160 | 161 | protected void generateHeader() { 162 | FastVector attributes = new FastVector(); 163 | for (int i = 0; i < this.numAttsOption.getValue(); i++) { 164 | attributes.addElement(new Attribute("att" + (i + 1))); 165 | } 166 | FastVector classLabels = new FastVector(); 167 | for (int i = 0; i < this.numClassesOption.getValue(); i++) { 168 | classLabels.addElement("class" + (i + 1)); 169 | } 170 | attributes.addElement(new Attribute("class", classLabels)); 171 | this.streamHeader = new InstancesHeader(new Instances( 172 | getCLICreationString(InstanceStream.class), attributes, 0)); 173 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1); 174 | } 175 | 176 | protected void generateCentroids() { 177 | Random modelRand = new Random(this.modelRandomSeedOption.getValue()); 178 | this.centroids = new Centroid[this.numCentroidsOption.getValue()]; 179 | this.centroidWeights = new double[this.centroids.length]; 180 | for (int i = 0; i < this.centroids.length; i++) { 181 | 
this.centroids[i] = new Centroid(); 182 | double[] randCentre = new double[this.numAttsOption.getValue()]; 183 | for (int j = 0; j < randCentre.length; j++) { 184 | randCentre[j] = modelRand.nextDouble(); 185 | } 186 | this.centroids[i].centre = randCentre; 187 | this.centroids[i].classLabel = modelRand.nextInt(this.numClassesOption.getValue()); 188 | this.centroids[i].stdDev = modelRand.nextDouble(); 189 | this.centroidWeights[i] = modelRand.nextDouble(); 190 | } 191 | } 192 | 193 | @Override 194 | public void getDescription(StringBuilder sb, int indent) { 195 | // TODO Auto-generated method stub 196 | } 197 | } 198 | -------------------------------------------------------------------------------- /src/main/java/moa/streams/generators/imbalanced/RandomRBFGeneratorDrift.java: -------------------------------------------------------------------------------- 1 | /* 2 | * RandomRBFGeneratorDrift.java 3 | * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand 4 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | * 19 | */ 20 | package moa.streams.generators.imbalanced; 21 | 22 | import java.util.Random; 23 | import moa.core.InstanceExample; 24 | 25 | import com.github.javacliparser.IntOption; 26 | import com.github.javacliparser.FloatOption; 27 | import com.yahoo.labs.samoa.instances.Instance; 28 | 29 | /** 30 | * Stream generator for a random radial basis function stream with drift. 31 | * 32 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 33 | * @version $Revision: 7 $ 34 | */ 35 | public class RandomRBFGeneratorDrift extends RandomRBFGenerator { 36 | 37 | @Override 38 | public String getPurposeString() { 39 | return "Generates a random radial basis function stream with drift."; 40 | } 41 | 42 | private static final long serialVersionUID = 1L; 43 | 44 | public FloatOption speedChangeOption = new FloatOption("speedChange", 's', 45 | "Speed of change of centroids in the model.", 0, 0, Float.MAX_VALUE); 46 | 47 | public IntOption numDriftCentroidsOption = new IntOption("numDriftCentroids", 'k', 48 | "The number of centroids with drift.", 50, 0, Integer.MAX_VALUE); 49 | 50 | protected double[][] speedCentroids; 51 | 52 | @Override 53 | public InstanceExample nextInstance() { 54 | //Update Centroids with drift 55 | int len = this.numDriftCentroidsOption.getValue(); 56 | if (len > this.centroids.length) { 57 | len = this.centroids.length; 58 | } 59 | for (int j = 0; j < len; j++) { 60 | for (int i = 0; i < this.numAttsOption.getValue(); i++) { 61 | this.centroids[j].centre[i] += this.speedCentroids[j][i] * this.speedChangeOption.getValue(); 62 | if (this.centroids[j].centre[i] > 1) { 63 | this.centroids[j].centre[i] = 1; 64 | this.speedCentroids[j][i] = -this.speedCentroids[j][i]; 65 | } 66 | if (this.centroids[j].centre[i] < 0) { 67 | this.centroids[j].centre[i] = 0; 68 | this.speedCentroids[j][i] = -this.speedCentroids[j][i]; 69 | } 70 | } 71 | } 72 | return super.nextInstance(); 73 | } 74 | 75 | @Override 76 | protected void generateCentroids() { 77 | 
super.generateCentroids(); 78 | Random modelRand = new Random(this.modelRandomSeedOption.getValue()); 79 | int len = this.numDriftCentroidsOption.getValue(); 80 | if (len > this.centroids.length) { 81 | len = this.centroids.length; 82 | } 83 | this.speedCentroids = new double[len][this.numAttsOption.getValue()]; 84 | for (int i = 0; i < len; i++) { 85 | double[] randSpeed = new double[this.numAttsOption.getValue()]; 86 | double normSpeed = 0.0; 87 | for (int j = 0; j < randSpeed.length; j++) { 88 | randSpeed[j] = modelRand.nextDouble(); 89 | normSpeed += randSpeed[j] * randSpeed[j]; 90 | } 91 | normSpeed = Math.sqrt(normSpeed); 92 | for (int j = 0; j < randSpeed.length; j++) { 93 | randSpeed[j] /= normSpeed; 94 | } 95 | this.speedCentroids[i] = randSpeed; 96 | } 97 | } 98 | 99 | @Override 100 | public void getDescription(StringBuilder sb, int indent) { 101 | // TODO Auto-generated method stub 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/main/java/moa/streams/generators/imbalanced/RandomTreeGenerator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * RandomTreeGenerator.java 3 | * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand 4 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | * 19 | */ 20 | package moa.streams.generators.imbalanced; 21 | 22 | import com.yahoo.labs.samoa.instances.Attribute; 23 | import com.yahoo.labs.samoa.instances.DenseInstance; 24 | import moa.core.FastVector; 25 | import com.yahoo.labs.samoa.instances.Instance; 26 | import com.yahoo.labs.samoa.instances.Instances; 27 | 28 | import java.io.Serializable; 29 | import java.util.ArrayList; 30 | import java.util.Random; 31 | import moa.core.InstanceExample; 32 | 33 | import com.yahoo.labs.samoa.instances.InstancesHeader; 34 | import moa.core.ObjectRepository; 35 | import moa.options.AbstractOptionHandler; 36 | import com.github.javacliparser.FloatOption; 37 | import com.github.javacliparser.IntOption; 38 | import moa.streams.InstanceStream; 39 | import moa.tasks.TaskMonitor; 40 | 41 | /** 42 | * Stream generator for a stream based on a randomly generated tree.. 43 | * 44 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) 45 | * @version $Revision: 7 $ 46 | */ 47 | public class RandomTreeGenerator extends AbstractOptionHandler implements 48 | InstanceStream { 49 | 50 | @Override 51 | public String getPurposeString() { 52 | return "Generates a stream based on a randomly generated tree."; 53 | } 54 | 55 | private static final long serialVersionUID = 1L; 56 | 57 | public IntOption treeRandomSeedOption = new IntOption("treeRandomSeed", 58 | 'r', "Seed for random generation of tree.", 1); 59 | 60 | public IntOption instanceRandomSeedOption = new IntOption( 61 | "instanceRandomSeed", 'i', 62 | "Seed for random generation of instances.", 1); 63 | 64 | public IntOption numClassesOption = new IntOption("numClasses", 'c', 65 | "The number of classes to generate.", 2, 2, Integer.MAX_VALUE); 66 | 67 | public IntOption numNominalsOption = new IntOption("numNominals", 'o', 68 | "The number of nominal attributes to generate.", 5, 0, 69 | Integer.MAX_VALUE); 70 | 71 | public IntOption numNumericsOption = new IntOption("numNumerics", 'u', 72 | "The number of numeric attributes to 
generate.", 5, 0, 73 | Integer.MAX_VALUE); 74 | 75 | public IntOption numValsPerNominalOption = new IntOption( 76 | "numValsPerNominal", 'v', 77 | "The number of values to generate per nominal attribute.", 5, 2, 78 | Integer.MAX_VALUE); 79 | 80 | public IntOption maxTreeDepthOption = new IntOption("maxTreeDepth", 'd', 81 | "The maximum depth of the tree concept.", 5, 0, Integer.MAX_VALUE); 82 | 83 | public IntOption firstLeafLevelOption = new IntOption( 84 | "firstLeafLevel", 85 | 'l', 86 | "The first level of the tree above maxTreeDepth that can have leaves.", 87 | 3, 0, Integer.MAX_VALUE); 88 | 89 | public FloatOption leafFractionOption = new FloatOption("leafFraction", 90 | 'f', 91 | "The fraction of leaves per level from firstLeafLevel onwards.", 92 | 0.15, 0.0, 1.0); 93 | 94 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm', 95 | "Percentage of minority class examples", 0.1, 0, 1); 96 | 97 | protected static class Node implements Serializable { 98 | 99 | private static final long serialVersionUID = 1L; 100 | 101 | public int classLabel; 102 | 103 | public int splitAttIndex; 104 | 105 | public double splitAttValue; 106 | 107 | public Node[] children; 108 | } 109 | 110 | protected Node treeRoot; 111 | 112 | protected InstancesHeader streamHeader; 113 | 114 | protected Random instanceRandom; 115 | 116 | @Override 117 | public void prepareForUseImpl(TaskMonitor monitor, 118 | ObjectRepository repository) { 119 | monitor.setCurrentActivity("Preparing random tree...", -1.0); 120 | generateHeader(); 121 | generateRandomTree(); 122 | restart(); 123 | } 124 | 125 | @Override 126 | public long estimatedRemainingInstances() { 127 | return -1; 128 | } 129 | 130 | @Override 131 | public boolean isRestartable() { 132 | return true; 133 | } 134 | 135 | @Override 136 | public void restart() { 137 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue()); 138 | } 139 | 140 | @Override 141 | public InstancesHeader getHeader() { 142 | 
return this.streamHeader; 143 | } 144 | 145 | @Override 146 | public boolean hasMoreInstances() { 147 | return true; 148 | } 149 | 150 | @Override 151 | public InstanceExample nextInstance() { 152 | double[] attVals = new double[this.numNominalsOption.getValue() + this.numNumericsOption.getValue()]; 153 | InstancesHeader header = getHeader(); 154 | Instance inst = new DenseInstance(header.numAttributes()); 155 | inst.setDataset(header); 156 | 157 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0; 158 | 159 | do{ 160 | for (int i = 0; i < attVals.length; i++) { 161 | attVals[i] = i < this.numNominalsOption.getValue() ? this.instanceRandom.nextInt(this.numValsPerNominalOption.getValue()) : this.instanceRandom.nextDouble(); 162 | inst.setValue(i, attVals[i]); 163 | } 164 | inst.setClassValue(classifyInstance(this.treeRoot, attVals)); 165 | }while(label != inst.classValue()); 166 | 167 | return new InstanceExample(inst); 168 | } 169 | 170 | protected int classifyInstance(Node node, double[] attVals) { 171 | if (node.children == null) { 172 | return node.classLabel; 173 | } 174 | if (node.splitAttIndex < this.numNominalsOption.getValue()) { 175 | return classifyInstance( 176 | node.children[(int) attVals[node.splitAttIndex]], attVals); 177 | } 178 | return classifyInstance( 179 | node.children[attVals[node.splitAttIndex] < node.splitAttValue ? 
0 180 | : 1], attVals); 181 | } 182 | 183 | protected void generateHeader() { 184 | FastVector attributes = new FastVector(); 185 | FastVector nominalAttVals = new FastVector(); 186 | for (int i = 0; i < this.numValsPerNominalOption.getValue(); i++) { 187 | nominalAttVals.addElement("value" + (i + 1)); 188 | } 189 | for (int i = 0; i < this.numNominalsOption.getValue(); i++) { 190 | attributes.addElement(new Attribute("nominal" + (i + 1), 191 | nominalAttVals)); 192 | } 193 | for (int i = 0; i < this.numNumericsOption.getValue(); i++) { 194 | attributes.addElement(new Attribute("numeric" + (i + 1))); 195 | } 196 | FastVector classLabels = new FastVector(); 197 | for (int i = 0; i < this.numClassesOption.getValue(); i++) { 198 | classLabels.addElement("class" + (i + 1)); 199 | } 200 | attributes.addElement(new Attribute("class", classLabels)); 201 | this.streamHeader = new InstancesHeader(new Instances( 202 | getCLICreationString(InstanceStream.class), attributes, 0)); 203 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1); 204 | } 205 | 206 | protected void generateRandomTree() { 207 | Random treeRand = new Random(this.treeRandomSeedOption.getValue()); 208 | ArrayList nominalAttCandidates = new ArrayList( 209 | this.numNominalsOption.getValue()); 210 | for (int i = 0; i < this.numNominalsOption.getValue(); i++) { 211 | nominalAttCandidates.add(i); 212 | } 213 | double[] minNumericVals = new double[this.numNumericsOption.getValue()]; 214 | double[] maxNumericVals = new double[this.numNumericsOption.getValue()]; 215 | for (int i = 0; i < this.numNumericsOption.getValue(); i++) { 216 | minNumericVals[i] = 0.0; 217 | maxNumericVals[i] = 1.0; 218 | } 219 | this.treeRoot = generateRandomTreeNode(0, nominalAttCandidates, 220 | minNumericVals, maxNumericVals, treeRand); 221 | } 222 | 223 | protected Node generateRandomTreeNode(int currentDepth, 224 | ArrayList nominalAttCandidates, double[] minNumericVals, 225 | double[] maxNumericVals, Random 
treeRand) { 226 | if ((currentDepth >= this.maxTreeDepthOption.getValue()) 227 | || ((currentDepth >= this.firstLeafLevelOption.getValue()) && (this.leafFractionOption.getValue() >= (1.0 - treeRand.nextDouble())))) { 228 | Node leaf = new Node(); 229 | leaf.classLabel = treeRand.nextInt(this.numClassesOption.getValue()); 230 | return leaf; 231 | } 232 | Node node = new Node(); 233 | int chosenAtt = treeRand.nextInt(nominalAttCandidates.size() 234 | + this.numNumericsOption.getValue()); 235 | if (chosenAtt < nominalAttCandidates.size()) { 236 | node.splitAttIndex = nominalAttCandidates.get(chosenAtt); 237 | node.children = new Node[this.numValsPerNominalOption.getValue()]; 238 | ArrayList newNominalCandidates = new ArrayList( 239 | nominalAttCandidates); 240 | newNominalCandidates.remove(new Integer(node.splitAttIndex)); 241 | newNominalCandidates.trimToSize(); 242 | for (int i = 0; i < node.children.length; i++) { 243 | node.children[i] = generateRandomTreeNode(currentDepth + 1, 244 | newNominalCandidates, minNumericVals, maxNumericVals, 245 | treeRand); 246 | } 247 | } else { 248 | int numericIndex = chosenAtt - nominalAttCandidates.size(); 249 | node.splitAttIndex = this.numNominalsOption.getValue() 250 | + numericIndex; 251 | double minVal = minNumericVals[numericIndex]; 252 | double maxVal = maxNumericVals[numericIndex]; 253 | node.splitAttValue = ((maxVal - minVal) * treeRand.nextDouble()) 254 | + minVal; 255 | node.children = new Node[2]; 256 | double[] newMaxVals = maxNumericVals.clone(); 257 | newMaxVals[numericIndex] = node.splitAttValue; 258 | node.children[0] = generateRandomTreeNode(currentDepth + 1, 259 | nominalAttCandidates, minNumericVals, newMaxVals, treeRand); 260 | double[] newMinVals = minNumericVals.clone(); 261 | newMinVals[numericIndex] = node.splitAttValue; 262 | node.children[1] = generateRandomTreeNode(currentDepth + 1, 263 | nominalAttCandidates, newMinVals, maxNumericVals, treeRand); 264 | } 265 | return node; 266 | } 267 | 268 | 
@Override 269 | public void getDescription(StringBuilder sb, int indent) { 270 | // TODO Auto-generated method stub 271 | } 272 | } 273 | -------------------------------------------------------------------------------- /src/main/java/moa/streams/generators/imbalanced/SEAGenerator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * SEAGenerator.java 3 | * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand 4 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 
18 | * 19 | */ 20 | package moa.streams.generators.imbalanced; 21 | 22 | import com.yahoo.labs.samoa.instances.Attribute; 23 | import com.yahoo.labs.samoa.instances.DenseInstance; 24 | import moa.core.FastVector; 25 | import com.yahoo.labs.samoa.instances.Instance; 26 | import com.yahoo.labs.samoa.instances.Instances; 27 | 28 | import java.util.Random; 29 | import moa.core.InstanceExample; 30 | 31 | import com.yahoo.labs.samoa.instances.InstancesHeader; 32 | import moa.core.ObjectRepository; 33 | import moa.options.AbstractOptionHandler; 34 | import com.github.javacliparser.FlagOption; 35 | import com.github.javacliparser.FloatOption; 36 | import com.github.javacliparser.IntOption; 37 | import moa.streams.InstanceStream; 38 | import moa.tasks.TaskMonitor; 39 | 40 | /** 41 | * Stream generator for SEA concepts functions. 42 | * Generator described in the paper:
43 | * W. Nick Street and YongSeog Kim 44 | * "A streaming ensemble algorithm (SEA) for large-scale classification", 45 | * KDD '01: Proceedings of the seventh ACM SIGKDD international conference on Knowledge discovery and data mining 46 | * 377-382 2001.

47 | * 48 | * Notes:
49 | * The built in functions are based on the paper. 50 | * 51 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 52 | * @version $Revision: 7 $ 53 | */ 54 | public class SEAGenerator extends AbstractOptionHandler implements 55 | InstanceStream { 56 | 57 | @Override 58 | public String getPurposeString() { 59 | return "Generates SEA concepts functions."; 60 | } 61 | 62 | private static final long serialVersionUID = 1L; 63 | 64 | public IntOption functionOption = new IntOption("function", 'f', 65 | "Classification function used, as defined in the original paper.", 66 | 1, 1, 4); 67 | 68 | public IntOption instanceRandomSeedOption = new IntOption( 69 | "instanceRandomSeed", 'i', 70 | "Seed for random generation of instances.", 1); 71 | 72 | public IntOption numInstancesConcept = new IntOption("numInstancesConcept", 'n', 73 | "The number of instances for each concept.", 0, 0, Integer.MAX_VALUE); 74 | 75 | public IntOption noisePercentageOption = new IntOption("noisePercentage", 76 | 'p', "Percentage of noise to add to the data.", 10, 0, 100); 77 | 78 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm', 79 | "Percentage of minority class examples", 0.1, 0, 1); 80 | 81 | protected interface ClassFunction { 82 | 83 | public int determineClass(double attrib1, double attrib2, double attrib3); 84 | } 85 | 86 | protected static ClassFunction[] classificationFunctions = { 87 | // function 1 88 | new ClassFunction() { 89 | 90 | @Override 91 | public int determineClass(double attrib1, double attrib2, double attrib3) { 92 | return (attrib1 + attrib2 <= 8) ? 0 : 1; 93 | } 94 | }, 95 | // function 2 96 | new ClassFunction() { 97 | 98 | @Override 99 | public int determineClass(double attrib1, double attrib2, double attrib3) { 100 | return (attrib1 + attrib2 <= 9) ? 
0 : 1; 101 | } 102 | }, 103 | // function 3 104 | new ClassFunction() { 105 | 106 | public int determineClass(double attrib1, double attrib2, double attrib3) { 107 | return (attrib1 + attrib2 <= 7) ? 0 : 1; 108 | } 109 | }, 110 | // function 4 111 | new ClassFunction() { 112 | 113 | @Override 114 | public int determineClass(double attrib1, double attrib2, double attrib3) { 115 | return (attrib1 + attrib2 <= 9.5) ? 0 : 1; 116 | } 117 | } 118 | }; 119 | 120 | protected InstancesHeader streamHeader; 121 | 122 | protected Random instanceRandom; 123 | 124 | protected boolean nextClassShouldBeZero; 125 | 126 | @Override 127 | protected void prepareForUseImpl(TaskMonitor monitor, 128 | ObjectRepository repository) { 129 | // generate header 130 | FastVector attributes = new FastVector(); 131 | attributes.addElement(new Attribute("attrib1")); 132 | attributes.addElement(new Attribute("attrib2")); 133 | attributes.addElement(new Attribute("attrib3")); 134 | 135 | FastVector classLabels = new FastVector(); 136 | classLabels.addElement("groupA"); 137 | classLabels.addElement("groupB"); 138 | attributes.addElement(new Attribute("class", classLabels)); 139 | this.streamHeader = new InstancesHeader(new Instances( 140 | getCLICreationString(InstanceStream.class), attributes, 0)); 141 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1); 142 | restart(); 143 | } 144 | 145 | @Override 146 | public long estimatedRemainingInstances() { 147 | return -1; 148 | } 149 | 150 | @Override 151 | public InstancesHeader getHeader() { 152 | return this.streamHeader; 153 | } 154 | 155 | @Override 156 | public boolean hasMoreInstances() { 157 | return true; 158 | } 159 | 160 | @Override 161 | public boolean isRestartable() { 162 | return true; 163 | } 164 | 165 | @Override 166 | public InstanceExample nextInstance() { 167 | double attrib1 = 0, attrib2 = 0, attrib3 = 0; 168 | int group = 0; 169 | 170 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 
1 : 0; 171 | 172 | do { 173 | // generate attributes 174 | attrib1 = 10 * this.instanceRandom.nextDouble(); 175 | attrib2 = 10 * this.instanceRandom.nextDouble(); 176 | attrib3 = 10 * this.instanceRandom.nextDouble(); 177 | 178 | // determine class 179 | group = classificationFunctions[this.functionOption.getValue() - 1].determineClass(attrib1, attrib2, attrib3); 180 | }while(group != label); 181 | 182 | //Add Noise 183 | if ((1 + (this.instanceRandom.nextInt(100))) <= this.noisePercentageOption.getValue()) { 184 | group = (group == 0 ? 1 : 0); 185 | } 186 | 187 | // construct instance 188 | InstancesHeader header = getHeader(); 189 | Instance inst = new DenseInstance(header.numAttributes()); 190 | inst.setValue(0, attrib1); 191 | inst.setValue(1, attrib2); 192 | inst.setValue(2, attrib3); 193 | inst.setDataset(header); 194 | inst.setClassValue(group); 195 | return new InstanceExample(inst); 196 | } 197 | 198 | @Override 199 | public void restart() { 200 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue()); 201 | this.nextClassShouldBeZero = false; 202 | } 203 | 204 | @Override 205 | public void getDescription(StringBuilder sb, int indent) { 206 | // TODO Auto-generated method stub 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /src/main/java/moa/streams/generators/imbalanced/STAGGERGenerator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * STAGGERGenerator.java 3 | * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand 4 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | * 19 | */ 20 | package moa.streams.generators.imbalanced; 21 | 22 | import com.yahoo.labs.samoa.instances.Attribute; 23 | import com.yahoo.labs.samoa.instances.DenseInstance; 24 | import moa.core.FastVector; 25 | import com.yahoo.labs.samoa.instances.Instance; 26 | import com.yahoo.labs.samoa.instances.Instances; 27 | 28 | import java.util.Random; 29 | import moa.core.InstanceExample; 30 | 31 | import com.yahoo.labs.samoa.instances.InstancesHeader; 32 | import moa.core.ObjectRepository; 33 | import moa.options.AbstractOptionHandler; 34 | import com.github.javacliparser.FlagOption; 35 | import com.github.javacliparser.FloatOption; 36 | import com.github.javacliparser.IntOption; 37 | import moa.streams.InstanceStream; 38 | import moa.tasks.TaskMonitor; 39 | 40 | /** 41 | * Stream generator for STAGGER Concept functions. 42 | * 43 | * Generator described in the paper:
44 | * Jeffrey C. Schlimmer and Richard H. Granger Jr. 45 | * "Incremental Learning from Noisy Data", 46 | * Machine Learning 1: 317-354 1986.

47 | * 48 | * Notes:
49 | * The built in functions are based on the paper (page 341). 50 | * 51 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 52 | * @version $Revision: 7 $ 53 | */ 54 | public class STAGGERGenerator extends AbstractOptionHandler implements 55 | InstanceStream { 56 | 57 | @Override 58 | public String getPurposeString() { 59 | return "Generates STAGGER Concept functions."; 60 | } 61 | 62 | private static final long serialVersionUID = 1L; 63 | 64 | public IntOption instanceRandomSeedOption = new IntOption( 65 | "instanceRandomSeed", 'i', 66 | "Seed for random generation of instances.", 1); 67 | 68 | public IntOption functionOption = new IntOption("function", 'f', 69 | "Classification function used, as defined in the original paper.", 70 | 1, 1, 3); 71 | 72 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm', 73 | "Percentage of minority class examples", 0.1, 0, 1); 74 | 75 | protected interface ClassFunction { 76 | 77 | public int determineClass(int size, int color, int shape); 78 | } 79 | 80 | protected static ClassFunction[] classificationFunctions = { 81 | // function 1 82 | new ClassFunction() { 83 | 84 | @Override 85 | public int determineClass(int size, int color, int shape) { 86 | return (size == 0 && color == 0) ? 1 : 0; //size==small && color==red 87 | } 88 | }, 89 | // function 2 90 | new ClassFunction() { 91 | 92 | @Override 93 | public int determineClass(int size, int color, int shape) { 94 | return (color == 2 || shape == 0) ? 1 : 0; //color==green || shape==circle 95 | } 96 | }, 97 | // function 3 98 | new ClassFunction() { 99 | 100 | @Override 101 | public int determineClass(int size, int color, int shape) { 102 | return (size == 1 || size == 2) ? 
1 : 0; // size==medium || size==large 103 | } 104 | } 105 | }; 106 | 107 | protected InstancesHeader streamHeader; 108 | 109 | protected Random instanceRandom; 110 | 111 | protected boolean nextClassShouldBeZero; 112 | 113 | @Override 114 | protected void prepareForUseImpl(TaskMonitor monitor, 115 | ObjectRepository repository) { 116 | // generate header 117 | FastVector attributes = new FastVector(); 118 | 119 | FastVector sizeLabels = new FastVector(); 120 | sizeLabels.addElement("small"); 121 | sizeLabels.addElement("medium"); 122 | sizeLabels.addElement("large"); 123 | attributes.addElement(new Attribute("size", sizeLabels)); 124 | 125 | FastVector colorLabels = new FastVector(); 126 | colorLabels.addElement("red"); 127 | colorLabels.addElement("blue"); 128 | colorLabels.addElement("green"); 129 | attributes.addElement(new Attribute("color", colorLabels)); 130 | 131 | FastVector shapeLabels = new FastVector(); 132 | shapeLabels.addElement("circle"); 133 | shapeLabels.addElement("square"); 134 | shapeLabels.addElement("triangle"); 135 | attributes.addElement(new Attribute("shape", shapeLabels)); 136 | 137 | FastVector classLabels = new FastVector(); 138 | classLabels.addElement("false"); 139 | classLabels.addElement("true"); 140 | attributes.addElement(new Attribute("class", classLabels)); 141 | this.streamHeader = new InstancesHeader(new Instances( 142 | getCLICreationString(InstanceStream.class), attributes, 0)); 143 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1); 144 | restart(); 145 | } 146 | 147 | @Override 148 | public long estimatedRemainingInstances() { 149 | return -1; 150 | } 151 | 152 | @Override 153 | public InstancesHeader getHeader() { 154 | return this.streamHeader; 155 | } 156 | 157 | @Override 158 | public boolean hasMoreInstances() { 159 | return true; 160 | } 161 | 162 | @Override 163 | public boolean isRestartable() { 164 | return true; 165 | } 166 | 167 | @Override 168 | public InstanceExample nextInstance() { 169 | 
170 | int size = 0, color = 0, shape = 0, group = 0; 171 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0; 172 | 173 | do { 174 | // generate attributes 175 | size = this.instanceRandom.nextInt(3); 176 | color = this.instanceRandom.nextInt(3); 177 | shape = this.instanceRandom.nextInt(3); 178 | 179 | // determine class 180 | group = classificationFunctions[this.functionOption.getValue() - 1].determineClass(size, color, shape); 181 | }while(group != label); 182 | 183 | // construct instance 184 | InstancesHeader header = getHeader(); 185 | Instance inst = new DenseInstance(header.numAttributes()); 186 | inst.setValue(0, size); 187 | inst.setValue(1, color); 188 | inst.setValue(2, shape); 189 | inst.setDataset(header); 190 | inst.setClassValue(group); 191 | return new InstanceExample(inst); 192 | } 193 | 194 | @Override 195 | public void restart() { 196 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue()); 197 | this.nextClassShouldBeZero = false; 198 | } 199 | 200 | @Override 201 | public void getDescription(StringBuilder sb, int indent) { 202 | // TODO Auto-generated method stub 203 | } 204 | } 205 | -------------------------------------------------------------------------------- /src/main/java/moa/streams/generators/imbalanced/SineGenerator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * SineGenerator.java 3 | * Copyright (C) 2016 Instituto Federal de Pernambuco 4 | * @author Paulo Gonçalves (paulogoncalves@recife.ifpe.edu.br) 5 | * 6 | * This program is free software; you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation; either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * This program is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with this program. If not, see . 18 | * 19 | */ 20 | package moa.streams.generators.imbalanced; 21 | 22 | import com.github.javacliparser.FlagOption; 23 | import com.github.javacliparser.FloatOption; 24 | import com.github.javacliparser.IntOption; 25 | import com.yahoo.labs.samoa.instances.Attribute; 26 | import com.yahoo.labs.samoa.instances.DenseInstance; 27 | import com.yahoo.labs.samoa.instances.Instance; 28 | import com.yahoo.labs.samoa.instances.Instances; 29 | import com.yahoo.labs.samoa.instances.InstancesHeader; 30 | import java.util.ArrayList; 31 | import java.util.List; 32 | import java.util.Random; 33 | import moa.core.InstanceExample; 34 | 35 | import moa.core.ObjectRepository; 36 | import moa.options.AbstractOptionHandler; 37 | import moa.streams.InstanceStream; 38 | import moa.tasks.TaskMonitor; 39 | 40 | /** 41 | * 1.SINE1. Abrupt concept drift, noise-free examples. It has two relevant 42 | * attributes. Each attributes has values uniformly distributed in [0; 1]. In 43 | * the first context all points below the curve y = sin(x) are classified as 44 | * positive. After the context change the classification is reversed. 45 | * 2.SINE2. The same two relevant attributes. The classification function is 46 | * y < 0.5 + 0.3 sin(3 * PI * x). After the context change the classification 47 | * is reversed. 48 | * 3.SINIRREL1. Presence of irrelevant attributes. The same classification 49 | * function of SINE1 but the examples have two more random attributes 50 | * with no influence on the classification function. 51 | * 4.SINIRREL2. 
The same classification function of SINE2 but the examples 52 | * have two more random attributes with no influence on the classification 53 | * function. 54 | * Based on proposal by "Gama, Joao, et al. "Learning with drift 55 | * detection." Advances in artificial intelligence–SBIA 2004. Springer Berlin 56 | * Heidelberg, 2004. 286-295." 57 | * 58 | * @author Paulo Gonçalves (paulogoncalves@recife.ifpe.edu.br) 59 | * @version $Revision: 1 $ 60 | */ 61 | public class SineGenerator extends AbstractOptionHandler implements 62 | InstanceStream { 63 | 64 | public static final int NUM_IRRELEVANT_ATTRIBUTES = 2; 65 | 66 | public IntOption instanceRandomSeedOption = new IntOption( 67 | "instanceRandomSeed", 'i', 68 | "Seed for random generation of instances.", 1); 69 | 70 | public IntOption functionOption = new IntOption("function", 'f', 71 | "Classification function used, as defined in the original paper.", 72 | 1, 1, 4); 73 | 74 | public FlagOption suppressIrrelevantAttributesOption = new FlagOption( 75 | "suppressIrrelevantAttributes", 's', 76 | "Reduce the data to only contain 2 relevant numeric attributes."); 77 | 78 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm', 79 | "Percentage of minority class examples", 0.1, 0, 1); 80 | 81 | protected InstancesHeader streamHeader; 82 | 83 | protected Random instanceRandom; 84 | 85 | protected boolean nextClassShouldBeZero; 86 | 87 | protected interface ClassFunction { 88 | 89 | public int determineClass(double x, double y); 90 | } 91 | 92 | protected static ClassFunction[] classificationFunctions = { 93 | // Values below the curve y = sin(x) are classified as positive. 94 | new ClassFunction() { 95 | 96 | @Override 97 | public int determineClass(double x, double y) { 98 | return (y < Math.sin(x)) ? 0 : 1; 99 | } 100 | }, 101 | // Values below the curve y = sin(x) are classified as negative. 
102 | new ClassFunction() { 103 | 104 | @Override 105 | public int determineClass(double x, double y) { 106 | return (y >= Math.sin(x)) ? 0 : 1; 107 | } 108 | }, 109 | // Values below the curve y = 0.5 + 0.3*sin(3*PI*x) are classified as positive. 110 | new ClassFunction() { 111 | 112 | @Override 113 | public int determineClass(double x, double y) { 114 | return (y < 0.5 + 0.3 * Math.sin(3 * Math.PI * x)) ? 0 : 1; 115 | } 116 | }, 117 | // Values below the curve y = 0.5 + 0.3*sin(3*PI*x) are classified as negative. 118 | new ClassFunction() { 119 | 120 | @Override 121 | public int determineClass(double x, double y) { 122 | return (y >= 0.5 + 0.3 * Math.sin(3 * Math.PI * x)) ? 0 : 1; 123 | } 124 | },}; 125 | 126 | @Override 127 | public void getDescription(StringBuilder sb, int indent) { 128 | 129 | } 130 | 131 | @Override 132 | public InstancesHeader getHeader() { 133 | return this.streamHeader; 134 | } 135 | 136 | @Override 137 | public long estimatedRemainingInstances() { 138 | return -1; 139 | } 140 | 141 | @Override 142 | public boolean hasMoreInstances() { 143 | return true; 144 | } 145 | 146 | @Override 147 | public InstanceExample nextInstance() { 148 | double a1 = 0, a2 = 0, group = 0; 149 | 150 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 
1 : 0; 151 | 152 | do { 153 | a1 = this.instanceRandom.nextDouble(); 154 | a2 = this.instanceRandom.nextDouble(); 155 | group = classificationFunctions[this.functionOption.getValue() - 1].determineClass(a1, a2); 156 | }while(group != label); 157 | 158 | // construct instance 159 | InstancesHeader header = getHeader(); 160 | Instance inst = new DenseInstance(header.numAttributes()); 161 | inst.setValue(0, a1); 162 | inst.setValue(1, a2); 163 | inst.setDataset(header); 164 | if (!this.suppressIrrelevantAttributesOption.isSet()) { 165 | for (int i = 0; i < NUM_IRRELEVANT_ATTRIBUTES; i++) { 166 | inst.setValue(i + 2, this.instanceRandom.nextDouble()); 167 | } 168 | } 169 | inst.setClassValue(group); 170 | return new InstanceExample(inst); 171 | } 172 | 173 | @Override 174 | public boolean isRestartable() { 175 | return true; 176 | } 177 | 178 | @Override 179 | public void restart() { 180 | this.instanceRandom = new Random( 181 | this.instanceRandomSeedOption.getValue()); 182 | this.nextClassShouldBeZero = false; 183 | } 184 | 185 | @Override 186 | protected void prepareForUseImpl(TaskMonitor monitor, 187 | ObjectRepository repository) { 188 | ArrayList attributes = new ArrayList(); 189 | 190 | int numAtts = 2; 191 | if (!this.suppressIrrelevantAttributesOption.isSet()) { 192 | numAtts += NUM_IRRELEVANT_ATTRIBUTES; 193 | } 194 | for (int i = 0; i < numAtts; i++) { 195 | attributes.add(new Attribute("att" + (i + 1))); 196 | } 197 | 198 | List classLabels = new ArrayList(); 199 | classLabels.add("positive"); 200 | classLabels.add("negative"); 201 | Attribute classAtt = new Attribute("class", classLabels); 202 | attributes.add(classAtt); 203 | 204 | this.streamHeader = new InstancesHeader(new Instances( 205 | getCLICreationString(InstanceStream.class), attributes, 0)); 206 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1); 207 | restart(); 208 | } 209 | } 210 | -------------------------------------------------------------------------------- 
/src/main/java/moa/streams/generators/imbalanced/TextGenerator.java: -------------------------------------------------------------------------------- 1 | package moa.streams.generators.imbalanced; 2 | 3 | import com.github.javacliparser.FloatOption; 4 | 5 | /* 6 | * 7 | * Licensed under the Apache License, Version 2.0 (the "License"); 8 | * you may not use this file except in compliance with the License. 9 | * You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 18 | * 19 | */ 20 | 21 | import com.github.javacliparser.IntOption; 22 | import com.yahoo.labs.samoa.instances.*; 23 | import moa.core.InstanceExample; 24 | import moa.core.ObjectRepository; 25 | import moa.options.AbstractOptionHandler; 26 | import moa.streams.InstanceStream; 27 | import moa.tasks.TaskMonitor; 28 | 29 | import java.util.ArrayList; 30 | import java.util.Random; 31 | 32 | /** 33 | * Text generator that simulates sentiment analysis on tweets. 
34 | */
public class TextGenerator extends AbstractOptionHandler implements InstanceStream {

    private static final long serialVersionUID = 3028905554604259131L;

    /** Size of the word table; also the number of non-class attributes in the header. */
    public IntOption numAttsOption = new IntOption("numAtts", 'a',
            "The number of attributes to generate.", 1000, 0, Integer.MAX_VALUE);

    /** Seed handed to the {@link Random} created in {@link #restart()}. */
    public IntOption instanceRandomSeedOption = new IntOption(
            "instanceRandomSeed", 'i',
            "Seed for random generation of instances.", 1);

    /**
     * Probability, in [0, 1], with which {@link #nextInstance()} asks for an
     * instance of class 1.
     */
    public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm',
            "Percentage of minority class examples", 0.1, 0, 1);

    protected InstancesHeader streamHeader;

    protected Random instanceRandom;

    /** Word id stored in each table slot; the id is used as an index into the attribute vector. */
    protected int[] wordTwitterGenerator;
    /** Normalised probability of each table slot. */
    protected double[] freqTwitterGenerator;
    /** Running (cumulative) sum of {@link #freqTwitterGenerator}. */
    protected double[] sumFreqTwitterGenerator;
    /** Polarity of each table slot: 1, 2, or 0 for no polarity. */
    protected int[] classTwitterGenerator;

    protected int sizeTable;
    protected double probPositive = 0.1;
    protected double probNegative = 0.1;
    protected double zipfExponent = 1.5;
    protected double lengthTweet = 15;

    /** Number of instances emitted since the last {@link #restart()}. */
    protected int countTweets = 0;

    @Override
    public InstancesHeader getHeader() {
        return this.streamHeader;
    }

    @Override
    public long estimatedRemainingInstances() {
        return -1;
    }

    @Override
    public boolean hasMoreInstances() {
        return true;
    }

    /**
     * Generates one tweet of the wanted class.
     *
     * <p>The wanted class is drawn first. Candidate tweets are then generated
     * until one has an unequal number of polarity-1 and polarity-2 words and
     * its resulting class (0 when polarity 1 wins, 1 otherwise) matches the
     * wanted class. The attribute vector is cleared before every candidate so
     * that only the words of the accepted tweet are marked in the result.
     *
     * @return the accepted tweet wrapped in an {@link InstanceExample}
     */
    @Override
    public InstanceExample nextInstance() {
        final double[] attVals = new double[this.numAttsOption.getValue() + 1];
        final int label = this.instanceRandom.nextDouble() < this.imbalanceRatio.getValue() ? 1 : 0;

        int tweetClass;
        do {
            int[] votes;
            do {
                // Discard the words of any previously rejected candidate.
                java.util.Arrays.fill(attVals, 0.0);

                int length = (int) (this.lengthTweet * (1.0 + this.instanceRandom.nextGaussian()));
                if (length < 1) {
                    length = 1;
                }
                votes = new int[3];
                for (int j = 0; j < length; j++) {
                    int slot = sampleWordSlot(this.instanceRandom.nextDouble());
                    attVals[this.wordTwitterGenerator[slot]] = 1;
                    votes[this.classTwitterGenerator[slot]]++;
                }
            } while (votes[1] == votes[2]); // ties carry no sentiment; try again
            tweetClass = (votes[1] > votes[2]) ? 0 : 1;
        } while (tweetClass != label);

        Instance inst = new SparseInstance(1.0, attVals);
        inst.setDataset(getHeader());
        inst.setClassValue(tweetClass);

        this.countTweets++;
        return new InstanceExample(inst);
    }

    /**
     * Returns the first table slot whose cumulative frequency is at least
     * {@code rand}, i.e. samples a slot with probability equal to its
     * normalised frequency. Falls back to the last slot if no cumulative
     * value reaches {@code rand}.
     *
     * @param rand a uniform draw from [0, 1)
     */
    private int sampleWordSlot(double rand) {
        int lo = 0;
        int hi = this.sizeTable - 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (this.sumFreqTwitterGenerator[mid] < rand) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    @Override
    public boolean isRestartable() {
        return true;
    }

    /**
     * Rebuilds the word table from the options and re-seeds the random
     * source. Slot {@code i} gets word id {@code i + 1}, a frequency
     * proportional to {@code 1 / (i + 1)^zipfExponent}, and a random
     * polarity (1 with probability probPositive, 2 with probability
     * probNegative, otherwise 0).
     */
    @Override
    public void restart() {

        this.sizeTable = this.numAttsOption.getValue();

        //Prepare table of words to generate tweets
        this.wordTwitterGenerator = new int[sizeTable];
        this.freqTwitterGenerator = new double[sizeTable];
        this.sumFreqTwitterGenerator = new double[sizeTable];
        this.classTwitterGenerator = new int[sizeTable];

        this.countTweets = 0;

        double sum = 0;
        this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
        for (int i = 0; i < this.sizeTable; i++) {
            // NOTE(review): ids run from 1 to sizeTable, so attribute index 0 is never
            // marked and the highest id lands on the class slot of the attribute vector
            // (later overwritten by setClassValue). Left unchanged here; confirm intent.
            this.wordTwitterGenerator[i] = i + 1;
            this.freqTwitterGenerator[i] = 1.0 / Math.pow(i + 1, zipfExponent);
            sum += this.freqTwitterGenerator[i];
            this.sumFreqTwitterGenerator[i] = sum;
            double rand = this.instanceRandom.nextDouble();
            this.classTwitterGenerator[i] = (rand < probPositive ? 1 : (rand < probNegative + probPositive ? 2 : 0));
        }
        // Normalise so that the cumulative table ends at 1.
        for (int i = 0; i < this.sizeTable; i++) {
            this.freqTwitterGenerator[i] /= sum;
            this.sumFreqTwitterGenerator[i] /= sum;
        }

    }

    @Override
    protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
        generateHeader();
        restart();
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {

    }

    /**
     * Builds the stream header: {@code numAtts} nominal attributes that share
     * the two class labels as their values, followed by the class attribute.
     */
    private void generateHeader() {
        ArrayList<String> classLabels = new ArrayList<String>();
        for (int i = 0; i < 2; i++) {
            classLabels.add("class" + (i + 1));
        }
        ArrayList<Attribute> attributes = new ArrayList<Attribute>();
        for (int i = 0; i < this.numAttsOption.getValue(); i++) {
            attributes.add(new Attribute("att" + (i + 1), classLabels));
        }
        attributes.add(new Attribute("class", classLabels));
        this.streamHeader = new InstancesHeader(new Instances(
                getCLICreationString(InstanceStream.class), attributes, 0));
        this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
    }

    /**
     * Flips the polarity (1 to 2, 2 to 1) of randomly chosen polar words
     * until {@code numberWords} flips have been made. Words are chosen with
     * replacement, so the same word may be flipped more than once. Does
     * nothing when the table holds no polar word.
     *
     * @param numberWords number of flips to perform
     */
    public void changePolarity(int numberWords) {
        if (!hasPolarWord()) {
            // Nothing can ever be flipped; looping would never terminate.
            return;
        }
        int flipped = 0;
        while (flipped < numberWords) {
            int randWord = this.instanceRandom.nextInt(this.sizeTable);
            int polarity = this.classTwitterGenerator[randWord];
            // Flip the word that was actually drawn, not the slot numbered by the counter.
            if (polarity == 1) {
                this.classTwitterGenerator[randWord] = 2;
                flipped++;
            } else if (polarity == 2) {
                this.classTwitterGenerator[randWord] = 1;
                flipped++;
            }
        }
    }

    /** Tells whether at least one table slot has polarity 1 or 2. */
    private boolean hasPolarWord() {
        for (int i = 0; i < this.sizeTable; i++) {
            int polarity = this.classTwitterGenerator[i];
            if (polarity == 1 || polarity == 2) {
                return true;
            }
        }
        return false;
    }

    /**
     * Swaps the word ids stored in {@code numberWords} randomly chosen pairs
     * of table slots, so the affected words exchange frequencies while the
     * table keeps holding every word id exactly once.
     *
     * @param numberWords number of swaps to perform
     */
    public void changeFreqWords(int numberWords) {
        for (int i = 0; i < numberWords; i++) {
            int randWordTo = this.instanceRandom.nextInt(this.sizeTable);
            int randWordFrom = this.instanceRandom.nextInt(this.sizeTable);
            int wordAtTo = this.wordTwitterGenerator[randWordTo];
            this.wordTwitterGenerator[randWordTo] = this.wordTwitterGenerator[randWordFrom];
            this.wordTwitterGenerator[randWordFrom] = wordAtTo;
        }
    }


}
213 |
--------------------------------------------------------------------------------