├── .classpath
├── .gitignore
├── .project
├── LICENSE
├── MOA-dependencies.jar
├── README.md
├── ROSE-1.0.jar
├── pom.xml
├── sizeofag-1.0.4.jar
└── src
└── main
└── java
├── experiments
├── Collect_Results.java
├── Datasets.java
├── Drifting_Imbalance_Ratio.java
├── Drifting_Noise_and_Imbalance_Ratio.java
├── Instance_Level_Difficulties.java
└── Static_Imbalance_Ratio.java
└── moa
├── classifiers
├── meta
│ └── imbalanced
│ │ └── ROSE.java
└── trees
│ └── RandomSubspaceHT.java
├── evaluation
├── WindowAUCImbalancedPerformanceEvaluator.java
├── WindowAUCMultiClassImbalancedPerformanceEvaluator.java
└── WindowImbalancedClassificationPerformanceEvaluator.java
└── streams
├── filters
└── AddNoiseFilterFeatures.java
└── generators
└── imbalanced
├── AgrawalGenerator.java
├── AssetNegotiationGenerator.java
├── HyperplaneGenerator.java
├── MixedGenerator.java
├── RandomRBFGenerator.java
├── RandomRBFGeneratorDrift.java
├── RandomTreeGenerator.java
├── SEAGenerator.java
├── STAGGERGenerator.java
├── SineGenerator.java
└── TextGenerator.java
/.classpath:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /target/
2 | /datasets/
3 |
--------------------------------------------------------------------------------
/.project:
--------------------------------------------------------------------------------
1 |
2 |
3 | ROSE
4 |
5 |
6 |
7 |
8 |
9 | org.eclipse.jdt.core.javabuilder
10 |
11 |
12 |
13 |
14 | org.eclipse.m2e.core.maven2Builder
15 |
16 |
17 |
18 |
19 |
20 | org.eclipse.jdt.core.javanature
21 | org.eclipse.m2e.core.maven2Nature
22 |
23 |
24 |
--------------------------------------------------------------------------------
/MOA-dependencies.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/canoalberto/ROSE/c86c8e111c8f823d394639a8a3a0e3d3598ad5a7/MOA-dependencies.jar
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ROSE: Robust Online Self-Adjusting Ensemble for Continual Learning from Imbalanced Drifting Data Streams
2 |
3 |
Data streams are potentially unbounded sequences of instances arriving over time to a classifier. Designing algorithms that are capable of dealing with massive, rapidly arriving information is one of the most dynamically developing areas of machine learning. Such learners must be able to deal with a phenomenon known as concept drift, where the data stream may be subject to various changes in its characteristics over time. Furthermore, distributions of classes may evolve over time, leading to a highly difficult non-stationary class imbalance. In this work we introduce Robust Online Self-Adjusting Ensemble (ROSE), a novel online ensemble classifier capable of dealing with all of the mentioned challenges. The main features of ROSE are: (i) online training of base classifiers on variable size random subsets of features; (ii) online detection of concept drift and creation of a background ensemble for faster adaptation to changes; (iii) sliding window per class to create skew-insensitive classifiers regardless of the current imbalance ratio; and (iv) self-adjusting bagging to enhance the exposure of difficult instances from minority classes. The interplay among these features leads to an improved performance in various data stream mining benchmarks. An extensive experimental study comparing with 30 ensemble classifiers shows that ROSE is a robust and well-rounded classifier for drifting imbalanced data streams, especially under the presence of noise and class imbalance drift, while maintaining competitive time complexity and memory consumption. Results are supported by a thorough non-parametric statistical analysis.
4 |
5 | ## Using ROSE
6 |
7 | Download the pre-compiled jar files or import the project source code into [MOA](https://github.com/Waikato/moa). See the src/main/java/experiments folder to reproduce our research.
8 |
9 | ### Experiment 1: Static imbalance ratio
10 | Use any algorithm in `moa.classifiers` and imbalanced generator in `moa.streams.generators.imbalanced`. The parameter `-m` controls the proportion of the minority vs majority class, e.g. `-m 0.01` reflects an imbalance ratio of 100.
11 | ```
12 | java -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar:MOA-dependencies.jar moa.DoTask EvaluateInterleavedTestThenTrain -e "(WindowAUCImbalancedPerformanceEvaluator)" -s "(moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 1 -m 0.01)" -l "(moa.classifiers.meta.imbalanced.ROSE)" -i 1000000 -f 500 -d results.csv
13 | ```
14 |
15 | | Generator | Instances | Features | Classes | Static Imbalance Ratios | Concept Drift |
16 | | -------- | ---: | ---: | ---: |:-: |:-: |
17 | | Agrawal | 1,000,000 | 9 | 2 | {5, 10, 20, 50, 100} | None |
18 | | AssetNegotiation | 1,000,000 | 5 | 2 | {5, 10, 20, 50, 100} | None |
19 | | RandomRBF | 1,000,000 | 10 | 2 | {5, 10, 20, 50, 100} | None |
20 | | SEA | 1,000,000 | 3 | 2 | {5, 10, 20, 50, 100} | None |
21 | | Sine | 1,000,000 | 4 | 2 | {5, 10, 20, 50, 100} | None |
22 | | Hyperplane | 1,000,000 | 10 | 2 | {5, 10, 20, 50, 100} | None |
23 |
24 | ### Experiment 2: Drifting imbalance ratio
25 | Use any algorithm in `moa.classifiers` and imbalanced generator in `moa.streams.generators.imbalanced`. The parameter `-m` controls the proportion of the minority vs majority class, e.g. `-m 0.01` reflects an imbalance ratio of 100. Generate drifting imbalance ratios by chaining `ConceptDriftStream` streams with different imbalance ratios. The parameter `-p` controls the position of the drift and `-w` the width of the drift (sudden vs gradual). The example shows a sequence of increasing then decreasing imbalance ratio ({5, 10, 20, 100, 20, 10, 5}).
26 | ```
27 | java -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar:MOA-dependencies.jar moa.DoTask EvaluateInterleavedTestThenTrain -e "(WindowAUCImbalancedPerformanceEvaluator)" -s "(ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.2) -r 1 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 2 -f 2 -m 0.1) -r 2 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 3 -f 2 -m 0.05) -r 3 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 4 -f 2 -m 0.01) -r 4 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 5 -f 2 -m 0.01) -r 5 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 6 -f 2 -m 0.05) -r 6 -d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 7 -f 2 -m 0.1) -r 7 -d (moa.streams.generators.imbalanced.AgrawalGenerator -i 8 -f 2 -m 0.2) -r 8 -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1)" -l "(moa.classifiers.meta.imbalanced.ROSE)" -i 1000000 -f 500 -d results.csv
28 | ```
29 |
30 | | Generator | Instances | Features | Classes | Drifting imbalance ratios | Concept Drift |
31 | | -------- | ---: | ---: | ---: |:-: |:-: |
32 | | Agrawal | 1,000,000 | 9 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} |
33 | | AssetNegotiation | 1,000,000 | 5 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} |
34 | | RandomRBF | 1,000,000 | 10 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} |
35 | | SEA | 1,000,000 | 3 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} |
36 | | Sine | 1,000,000 | 4 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} |
37 | | Hyperplane | 1,000,000 | 10 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} |
38 |
39 | ### Experiment 3: Instance-level difficulties
40 | Use any algorithm in `moa.classifiers` and a dataset with instance-level difficulties generated using these imbalanced generators.
41 |
42 | ```
43 | java -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar:MOA-dependencies.jar moa.DoTask EvaluateInterleavedTestThenTrain -e "(WindowAUCImbalancedPerformanceEvaluator)" -s "(ArffFileStream -f Split5+Im1+Borderline20+Rare20.arff)" -l "(moa.classifiers.meta.imbalanced.ROSE)" -f 500 -d results.csv
44 | ```
45 |
46 | | Generator | Instances | Features | Classes | Static Imbalance Ratios | Percentage of difficult instances |
47 | | -------- | ---: | ---: | ---: |:-: |:-- |
48 | | Borderline | 200,000 | 5 | 2 | {1, 10, 100} | {20%, 40%, 60%, 80%, 100%} |
49 | | Rare | 200,000 | 5 | 2 | {1, 10, 100} | {20%, 40%, 60%, 80%, 100%} |
50 | | Borderline + Rare | 200,000 | 5 | 2 | {1, 10, 100} | {20%, 40%} |
51 |
52 | ### Experiment 4: Robustness to noise drift
53 | Use any algorithm in `moa.classifiers` and imbalanced generator in `moa.streams.generators.imbalanced`. The parameter `-f` controls the percentage of features with noise. The parameter `-m` controls the proportion of the minority vs majority class, e.g. `-m 0.01` reflects an imbalance ratio of 100. Generate drifting noise and imbalance ratios by chaining `ConceptDriftStream` streams with different imbalance ratios, percentages of features with noise, and noise seed `-r`. The parameter `-p` controls the position of the drift and `-w` the width of the drift (sudden vs gradual). The example shows a sequence of drifting noise to other features and increasing then decreasing imbalance ratio ({5, 10, 20, 100, 20, 10, 5}).
54 | ```
55 | java -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar:MOA-dependencies.jar moa.DoTask EvaluateInterleavedTestThenTrain -e "(WindowAUCImbalancedPerformanceEvaluator)" -s "(ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.2) -f (AddNoiseFilterFeatures -r 1 -a 0.99 -f 0.40)) -r 1 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 2 -f 2 -m 0.1) -f (AddNoiseFilterFeatures -r 2 -a 0.99 -f 0.40)) -r 2 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 3 -f 2 -m 0.05) -f (AddNoiseFilterFeatures -r 3 -a 0.99 -f 0.40)) -r 3 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 4 -f 2 -m 0.01) -f (AddNoiseFilterFeatures -r 4 -a 0.99 -f 0.40)) -r 4 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 5 -f 2 -m 0.01) -f (AddNoiseFilterFeatures -r 5 -a 0.99 -f 0.40)) -r 5 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 6 -f 2 -m 0.05) -f (AddNoiseFilterFeatures -r 6 -a 0.99 -f 0.40)) -r 6 -d (ConceptDriftStream -s (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 7 -f 2 -m 0.1) -f (AddNoiseFilterFeatures -r 7 -a 0.99 -f 0.40)) -r 7 -d (FilteredStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 8 -f 2 -m 0.2) -f (AddNoiseFilterFeatures -r 8 -a 0.99 -f 0.40)) -r 8 -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1) -p 125000 -w 1)" -l "(moa.classifiers.meta.imbalanced.ROSE)" -i 1000000 -f 500 -d results.csv
56 | ```
57 |
58 | | Generator | Instances | Features | Classes | Drifting imbalance ratios | Concept Drift | Percentage of features with noise
59 | | -------- | ---: | ---: | ---: |:-: |:-: | :-- |
60 | | Agrawal | 1,000,000 | 9 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | {10%, 20%, 30%, 40%}
61 | | AssetNegotiation | 1,000,000 | 5 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | {10%, 20%, 30%, 40%}
62 | | RandomRBF | 1,000,000 | 10 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | {10%, 20%, 30%, 40%}
63 | | SEA | 1,000,000 | 3 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | {10%, 20%, 30%, 40%}
64 | | Sine | 1,000,000 | 4 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | {10%, 20%, 30%, 40%}
65 | | Hyperplane | 1,000,000 | 10 | 2 | {5, 10, 20, 100, 20, 10, 5} | 8 drifts {sudden, gradual} | {10%, 20%, 30%, 40%}
66 |
67 | ### Experiment 5: Datasets
68 | Use any algorithm in `moa.classifiers` and dataset from UCI / KEEL dataset repositories.
69 | ```
70 | java -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar:MOA-dependencies.jar moa.DoTask EvaluateInterleavedTestThenTrain -e "(WindowAUCMultiClassImbalancedPerformanceEvaluator)" -s "(ArffFileStream -f dataset.arff)" -l "(moa.classifiers.meta.imbalanced.ROSE)" -f 500 -d results.csv
71 | ```
72 |
73 | | Dataset | Instances | Features | Classes |
74 | | -------- | ---: | ---: | ---: |
75 | | adult | 45,222 | 14 | 2 |
76 | | airlines | 539,383 | 7 | 2 |
77 | | bridges | 1,000,000 | 12 | 6 |
78 | | census | 299,284 | 41 | 2 |
79 | | coil2000 | 9,822 | 85 | 2 |
80 | | connect-4 | 67,557 | 42 | 3 |
81 | | covtype | 581,012 | 54 | 7 |
82 | | dj30 | 138,166 | 7 | 30 |
83 | | electricity | 45,312 | 8 | 2 |
84 | | fars | 100,968 | 29 | 8 |
85 | | gas-sensor | 13,910 | 128 | 6 |
86 | | gmsc | 150,000 | 10 | 2 |
87 | | intel-lab | 2,313,153 | 5 | 58 |
88 | | kddcup | 4,898,431 | 41 | 23 |
89 | | kr-vs-k | 28,056 | 6 | 18 |
90 | | letter | 20,000 | 16 | 26 |
91 | | magic | 19,020 | 10 | 2 |
92 | | nomao | 34,465 | 118 | 2 |
93 | | penbased | 10,992 | 16 | 10 |
94 | | poker | 829,201 | 10 | 10 |
95 | | powersupply | 29,928 | 2 | 24 |
96 | | shuttle | 57,999 | 9 | 7 |
97 | | thyroid | 7,200 | 21 | 3 |
98 | | zoo | 1,000,000 | 17 | 7 |
99 |
100 | ## Citation
101 | ```
102 | @article{cano2022rose,
103 | title={{ROSE: robust online self-adjusting ensemble for continual learning on imbalanced drifting data streams}},
104 | author={Cano, Alberto and Krawczyk, Bartosz},
105 | journal={Machine Learning},
106 | volume={111},
107 | number={7},
108 | pages={2561--2599},
109 | year={2022}
110 | }
111 | ```
112 |
--------------------------------------------------------------------------------
/ROSE-1.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/canoalberto/ROSE/c86c8e111c8f823d394639a8a3a0e3d3598ad5a7/ROSE-1.0.jar
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
4 | 4.0.0
5 |
6 | edu.vcu.acano
7 | ROSE
8 | 1.0
9 |
10 | ROSE: Robust Online Self-Adjusting Ensemble for Continual Learning on Imbalanced Drifting Data Streams
11 | https://github.com/canoalberto/ROSE
12 |
13 |
14 | Virginia Commonwealth University, Richmond, Virginia, USA
15 | http://www.vcu.edu/
16 |
17 |
18 |
19 |
20 | GNU General Public License 3.0
21 | http://www.gnu.org/licenses/gpl-3.0.txt
22 | repo
23 |
24 |
25 |
26 |
27 |
28 | canoalberto
29 | Alberto Cano
30 | acano@vcu.edu
31 |
32 |
33 |
34 |
35 | UTF-8
36 | 1.11
37 | 1.11
38 |
39 |
40 |
41 |
42 | junit
43 | junit
44 | 4.13.2
45 | test
46 |
47 |
48 | nz.ac.waikato.cms.moa
49 | moa
50 | 2021.07.0
51 |
52 |
53 | nz.ac.waikato.cms.weka
54 | weka-dev
55 | 3.9.5
56 |
57 |
58 | org.apache.commons
59 | commons-math3
60 | 3.6.1
61 |
62 |
63 | gov.nist.math
64 | jama
65 | 1.0.3
66 |
67 |
68 |
69 |
70 |
71 |
72 | maven-assembly-plugin
73 |
74 |
75 |
76 | moa.DoTask
77 |
78 |
79 |
80 | jar-with-dependencies
81 |
82 |
83 |
84 |
85 | make-assembly
86 | package
87 |
88 | single
89 |
90 |
91 |
92 |
93 |
94 | org.apache.maven.plugins
95 | maven-compiler-plugin
96 | 3.8.1
97 |
98 | 11
99 | 11
100 |
101 |
102 |
103 |
104 |
--------------------------------------------------------------------------------
/sizeofag-1.0.4.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/canoalberto/ROSE/c86c8e111c8f823d394639a8a3a0e3d3598ad5a7/sizeofag-1.0.4.jar
--------------------------------------------------------------------------------
/src/main/java/experiments/Collect_Results.java:
--------------------------------------------------------------------------------
1 | package experiments;
2 |
3 | import java.io.BufferedReader;
4 | import java.io.File;
5 | import java.io.FileReader;
6 |
/**
 * Summarizes experiment result files as tab-separated tables on standard output.
 *
 * <p>For each metric, one table is printed: a header row with the metric name followed by
 * the algorithm names, then one row per file (dataset / generator) holding one value per
 * algorithm. Values are read from {@code <resultsPath>/<algorithm>-<file>.csv}.
 */
public class Collect_Results {

    public static void main(String[] args) throws Exception {

        String resultsPath = "results_datasets";

        // Replace with the collection of files (datasets / generators) to parse
        String[] files = new String[] {
                "thyroid",
                "coil2000",
                "penbased",
                "gas-sensor",
                "magic",
                "letter",
                "kr-vs-k",
                "powersupply",
                "adult",
                "electricity",
                "shuttle",
                "connect-4",
                "fars",
                "dj30",
                "census",
                "airlines",
                "covtype",
                "poker",
                "bridges",
                "zoo",
                "IntelLabSensors",
                "kddcup",
                "GMSC",
                "nomao",
        };

        String[] algorithms = new String[] {
                "moa.classifiers.meta.imbalanced.ROSE",
                "moa.classifiers.meta.KUE",
                "moa.classifiers.meta.AccuracyWeightedEnsemble",
                "moa.classifiers.meta.AccuracyUpdatedEnsemble1",
                "moa.classifiers.meta.AccuracyUpdatedEnsemble2",
                "moa.classifiers.meta.DynamicWeightedMajority",
                "moa.classifiers.meta.SAE2",
                "moa.classifiers.meta.DACC",
                "moa.classifiers.meta.ADACC",
                "moa.classifiers.meta.AdaptiveRandomForest",
                "moa.classifiers.meta.ADOB",
                "moa.classifiers.meta.BOLE",
                "moa.classifiers.meta.GOOWE",
                "moa.classifiers.meta.HeterogeneousEnsembleBlast",
                "moa.classifiers.meta.LeveragingBag",
                "moa.classifiers.meta.OCBoost",
                "moa.classifiers.meta.OzaBag",
                "moa.classifiers.meta.OzaBagAdwin",
                "moa.classifiers.meta.OzaBagASHT",
                "moa.classifiers.meta.OzaBoost",
                "moa.classifiers.meta.OzaBoostAdwin",
                "moa.classifiers.meta.StreamingRandomPatches",
                "moa.classifiers.meta.UOB",
                "moa.classifiers.meta.OOB",
                "moa.classifiers.meta.imbalanced.OnlineSMOTEBagging",
                "moa.classifiers.meta.imbalanced.OnlineUnderOverBagging",
                "moa.classifiers.meta.imbalanced.CSMOTE",
                "moa.classifiers.meta.imbalanced.OnlineAdaBoost",
                "moa.classifiers.meta.imbalanced.OnlineAdaC2",
                "moa.classifiers.meta.imbalanced.OnlineRUSBoost",
                "moa.classifiers.meta.imbalanced.RebalanceStream",
        };

        // Result files are named after the simple class name, so strip the package prefixes.
        // Literal replace() is used on purpose: replaceAll() would treat each '.' as a regex wildcard.
        String[] algorithmsFilename = new String[algorithms.length];

        for (int alg = 0; alg < algorithms.length; alg++) {
            algorithmsFilename[alg] = algorithms[alg].replace("moa.classifiers.meta.", "").replace("imbalanced.", "");
        }

        metric("Accuracy", "averaged", resultsPath, files, algorithmsFilename);
        metric("Kappa", "averaged", resultsPath, files, algorithmsFilename);
        metric("AUC", "averaged", resultsPath, files, algorithmsFilename); // do not use for multi-class datasets
        metric("PMAUC", "averaged", resultsPath, files, algorithmsFilename); // use for multi-class datasets
    }

    /**
     * Prints one tab-separated table for a single metric.
     *
     * @param metricName  header name of the CSV column to read
     * @param outcome     {@code "averaged"} (case-insensitive) to print the mean of the column;
     *                    any other value prints the last value read
     * @param resultsPath directory containing the result files
     * @param files       dataset / generator names (table rows)
     * @param algorithms  algorithm names as used in the result file names (table columns)
     * @throws Exception if a result file cannot be read
     */
    private static void metric(String metricName, String outcome, String resultsPath, String[] files, String[] algorithms) throws Exception {

        System.out.print(metricName + "\t");

        for (String algorithm : algorithms) {
            System.out.print(algorithm + "\t");
        }
        System.out.println("");

        boolean averaged = "averaged".equalsIgnoreCase(outcome);

        for (String file : files) {
            System.out.print(file + "\t");

            for (String algorithm : algorithms) {
                String filename = resultsPath + "/" + algorithm + "-" + file + ".csv";
                System.out.print(summarize(new File(filename), metricName, averaged) + "\t");
            }

            System.out.println("");
        }
    }

    /**
     * Reads one result file and reduces the requested column to a single number.
     *
     * <p>Cells equal to {@code "?"} are skipped. A row whose cell cannot be read (the column
     * is absent from the header, the row is too short, or the text is not a number) counts
     * as 0. A missing file, or a file with no usable rows, yields {@code NaN} when averaging
     * and 0 otherwise.
     *
     * @param csv        result file; may not exist
     * @param metricName header name of the column to read
     * @param averaged   {@code true} for the column mean, {@code false} for the last value read
     * @return the mean or the last value of the column
     * @throws Exception if the file exists but cannot be read
     */
    private static double summarize(File csv, String metricName, boolean averaged) throws Exception {

        int count = 0;
        double sum = 0;
        double lastValue = 0;

        if (csv.exists()) {
            // try-with-resources closes the reader even when reading fails midway
            try (BufferedReader br = new BufferedReader(new FileReader(csv))) {

                String headerLine = br.readLine();

                if (headerLine == null) {
                    // An empty result file aborts the whole run
                    System.err.println("Empty results file (no header line): " + csv);
                    System.exit(-1);
                }

                String[] columns = headerLine.split(",");

                // The last column with a matching name wins; -1 means the metric is absent
                int index = -1;
                for (int i = 0; i < columns.length; i++) {
                    if (columns[i].equals(metricName)) {
                        index = i;
                    }
                }

                String line;
                while ((line = br.readLine()) != null) {
                    try {
                        String cell = line.split(",")[index];

                        if (!cell.equals("?")) {
                            lastValue = Double.parseDouble(cell);
                            sum += lastValue;
                            count++;
                        }
                    } catch (NumberFormatException | ArrayIndexOutOfBoundsException e) {
                        // Unreadable cell: the row still counts, with value 0
                        lastValue = 0;
                        sum += lastValue;
                        count++;
                    }
                }
            }
        }

        return averaged ? sum / count : lastValue;
    }
}
--------------------------------------------------------------------------------
/src/main/java/experiments/Datasets.java:
--------------------------------------------------------------------------------
1 | package experiments;
2 |
3 | import org.apache.commons.lang3.SystemUtils;
4 |
/**
 * Prints, one per line, the command that evaluates each algorithm on each dataset with
 * MOA's {@code EvaluateInterleavedTestThenTrain} task.
 *
 * <p>Nothing is executed here; the commands are only written to standard output. Each
 * command reads {@code datasets/<dataset>.arff} and writes
 * {@code results_datasets/<algorithm>-<dataset>.csv}.
 */
public class Datasets {

    public static void main(String[] args) throws Exception {

        // Download datasets from https://people.vcu.edu/~acano/ROSE/datasets.zip

        String[] datasets = new String[] {
                "thyroid",
                "coil2000",
                "penbased",
                "gas-sensor",
                "magic",
                "letter",
                "kr-vs-k",
                "powersupply",
                "adult",
                "electricity",
                "shuttle",
                "connect-4",
                "fars",
                "dj30",
                "census",
                "airlines",
                "covtype",
                "poker",
                "bridges",
                "zoo",
                "IntelLabSensors",
                "kddcup",
                "GMSC",
                "nomao",
        };

        String[] algorithms = new String[] {
                "moa.classifiers.meta.imbalanced.ROSE",
                "moa.classifiers.meta.KUE",
                "moa.classifiers.meta.AccuracyWeightedEnsemble",
                "moa.classifiers.meta.AccuracyUpdatedEnsemble1",
                "moa.classifiers.meta.AccuracyUpdatedEnsemble2",
                "moa.classifiers.meta.DynamicWeightedMajority",
                "moa.classifiers.meta.SAE2",
                "moa.classifiers.meta.DACC",
                "moa.classifiers.meta.ADACC",
                "moa.classifiers.meta.AdaptiveRandomForest",
                "moa.classifiers.meta.ADOB",
                "moa.classifiers.meta.BOLE",
                "moa.classifiers.meta.GOOWE",
                "moa.classifiers.meta.HeterogeneousEnsembleBlast",
                "moa.classifiers.meta.LeveragingBag",
                "moa.classifiers.meta.OCBoost",
                "moa.classifiers.meta.OzaBag",
                "moa.classifiers.meta.OzaBagAdwin",
                "moa.classifiers.meta.OzaBagASHT",
                "moa.classifiers.meta.OzaBoost",
                "moa.classifiers.meta.OzaBoostAdwin",
                "moa.classifiers.meta.StreamingRandomPatches",
                "moa.classifiers.meta.UOB",
                "moa.classifiers.meta.OOB",
                "moa.classifiers.meta.imbalanced.OnlineSMOTEBagging",
                "moa.classifiers.meta.imbalanced.OnlineUnderOverBagging",
                "moa.classifiers.meta.imbalanced.CSMOTE",
                "moa.classifiers.meta.imbalanced.OnlineAdaBoost",
                "moa.classifiers.meta.imbalanced.OnlineAdaC2",
                "moa.classifiers.meta.imbalanced.OnlineRUSBoost",
                "moa.classifiers.meta.imbalanced.RebalanceStream",
        };

        // Result files are named after the simple class name, so strip the package prefixes.
        // Literal replace() is used on purpose: replaceAll() would treat each '.' as a regex wildcard.
        String[] algorithmsFilename = new String[algorithms.length];

        for (int alg = 0; alg < algorithms.length; alg++) {
            algorithmsFilename[alg] = algorithms[alg].replace("moa.classifiers.meta.", "").replace("imbalanced.", "");
        }

        // Platform classpath separator from the standard library (":" on Unix-like systems, ";" on Windows)
        String classpathSeparator = java.io.File.pathSeparator;

        for (int dat = 0; dat < datasets.length; dat++) {
            for (int alg = 0; alg < algorithms.length; alg++) {
                // Every command uses WindowAUCMultiClassImbalancedPerformanceEvaluator (the multi-class evaluator)
                System.out.println("java -Xms16g -Xmx1024g -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar" + classpathSeparator + "MOA-dependencies.jar "
                        + "moa.DoTask EvaluateInterleavedTestThenTrain"
                        + " -e \"(WindowAUCMultiClassImbalancedPerformanceEvaluator)\""
                        + " -s \"(ArffFileStream -f datasets/" + datasets[dat] + ".arff)\""
                        + " -l \"(" + algorithms[alg] + ")\""
                        + " -f 500"
                        + " -d results_datasets/" + algorithmsFilename[alg] + "-" + datasets[dat] + ".csv");
            }
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/experiments/Drifting_Imbalance_Ratio.java:
--------------------------------------------------------------------------------
1 | package experiments;
2 |
3 | import org.apache.commons.lang3.SystemUtils;
4 |
5 | public class Drifting_Imbalance_Ratio {
6 |
7 | public static void main(String[] args) throws Exception {
8 |
9 | String[] generators = new String[] {
10 | // Sudden drift
11 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.2) -r 1 "
12 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 2 -f 2 -m 0.1) -r 2 "
13 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 3 -f 2 -m 0.05) -r 3 "
14 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 4 -f 2 -m 0.01) -r 4 "
15 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 5 -f 2 -m 0.01) -r 5 "
16 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 6 -f 2 -m 0.05) -r 6 "
17 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 7 -f 2 -m 0.1) -r 7 "
18 | + "-d (moa.streams.generators.imbalanced.AgrawalGenerator -i 8 -f 2 -m 0.2) -r 8 "
19 | + "-p 125000 -w 1) "
20 | + "-p 125000 -w 1) "
21 | + "-p 125000 -w 1) "
22 | + "-p 125000 -w 1) "
23 | + "-p 125000 -w 1) "
24 | + "-p 125000 -w 1) "
25 | + "-p 125000 -w 1",
26 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 1 -f 1 -m 0.2) -r 1 "
27 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 2 -f 1 -m 0.1) -r 2 "
28 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 3 -f 1 -m 0.05) -r 3 "
29 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 4 -f 1 -m 0.01) -r 4 "
30 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 5 -f 1 -m 0.01) -r 5 "
31 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 6 -f 1 -m 0.05) -r 6 "
32 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 7 -f 1 -m 0.1) -r 7 "
33 | + "-d (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 8 -f 1 -m 0.2) -r 8 "
34 | + "-p 125000 -w 1) "
35 | + "-p 125000 -w 1) "
36 | + "-p 125000 -w 1) "
37 | + "-p 125000 -w 1) "
38 | + "-p 125000 -w 1) "
39 | + "-p 125000 -w 1) "
40 | + "-p 125000 -w 1",
41 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 1 -a 10 -c 2 -m 0.2) -r 1 "
42 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 2 -a 10 -c 2 -m 0.1) -r 2 "
43 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 3 -a 10 -c 2 -m 0.05) -r 3 "
44 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 4 -a 10 -c 2 -m 0.01) -r 4 "
45 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 5 -a 10 -c 2 -m 0.01) -r 5 "
46 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 6 -a 10 -c 2 -m 0.05) -r 6 "
47 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 7 -a 10 -c 2 -m 0.1) -r 7 "
48 | + "-d (moa.streams.generators.imbalanced.HyperplaneGenerator -i 8 -a 10 -c 2 -m 0.2) -r 8 "
49 | + "-p 125000 -w 1) "
50 | + "-p 125000 -w 1) "
51 | + "-p 125000 -w 1) "
52 | + "-p 125000 -w 1) "
53 | + "-p 125000 -w 1) "
54 | + "-p 125000 -w 1) "
55 | + "-p 125000 -w 1",
56 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 1 -r 1 -a 10 -c 2 -m 0.2) -r 1 "
57 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 2 -r 2 -a 10 -c 2 -m 0.1) -r 2 "
58 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 3 -r 3 -a 10 -c 2 -m 0.05) -r 3 "
59 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 4 -r 4 -a 10 -c 2 -m 0.01) -r 4 "
60 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 5 -r 5 -a 10 -c 2 -m 0.01) -r 5 "
61 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 6 -r 6 -a 10 -c 2 -m 0.05) -r 6 "
62 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 7 -r 7 -a 10 -c 2 -m 0.1) -r 7 "
63 | + "-d (moa.streams.generators.imbalanced.RandomRBFGenerator -i 8 -r 8 -a 10 -c 2 -m 0.2) -r 8 "
64 | + "-p 125000 -w 1) "
65 | + "-p 125000 -w 1) "
66 | + "-p 125000 -w 1) "
67 | + "-p 125000 -w 1) "
68 | + "-p 125000 -w 1) "
69 | + "-p 125000 -w 1) "
70 | + "-p 125000 -w 1",
71 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 1 -f 1 -m 0.2) -r 1 "
72 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 2 -f 1 -m 0.1) -r 2 "
73 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 3 -f 1 -m 0.05) -r 3 "
74 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 4 -f 1 -m 0.01) -r 4 "
75 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 5 -f 1 -m 0.01) -r 5 "
76 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 6 -f 1 -m 0.05) -r 6 "
77 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 7 -f 1 -m 0.1) -r 7 "
78 | + "-d (moa.streams.generators.imbalanced.SEAGenerator -i 8 -f 1 -m 0.2) -r 8 "
79 | + "-p 125000 -w 1) "
80 | + "-p 125000 -w 1) "
81 | + "-p 125000 -w 1) "
82 | + "-p 125000 -w 1) "
83 | + "-p 125000 -w 1) "
84 | + "-p 125000 -w 1) "
85 | + "-p 125000 -w 1",
86 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 1 -f 1 -m 0.2) -r 1 "
87 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 2 -f 1 -m 0.1) -r 2 "
88 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 3 -f 1 -m 0.05) -r 3 "
89 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 4 -f 1 -m 0.01) -r 4 "
90 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 5 -f 1 -m 0.01) -r 5 "
91 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 6 -f 1 -m 0.05) -r 6 "
92 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 7 -f 1 -m 0.1) -r 7 "
93 | + "-d (moa.streams.generators.imbalanced.SineGenerator -i 8 -f 1 -m 0.2) -r 8 "
94 | + "-p 125000 -w 1) "
95 | + "-p 125000 -w 1) "
96 | + "-p 125000 -w 1) "
97 | + "-p 125000 -w 1) "
98 | + "-p 125000 -w 1) "
99 | + "-p 125000 -w 1) "
100 | + "-p 125000 -w 1",
101 | // Gradual drift
102 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.2) -r 1 "
103 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 2 -f 2 -m 0.1) -r 2 "
104 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 3 -f 2 -m 0.05) -r 3 "
105 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 4 -f 2 -m 0.01) -r 4 "
106 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 5 -f 2 -m 0.01) -r 5 "
107 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 6 -f 2 -m 0.05) -r 6 "
108 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AgrawalGenerator -i 7 -f 2 -m 0.1) -r 7 "
109 | + "-d (moa.streams.generators.imbalanced.AgrawalGenerator -i 8 -f 2 -m 0.2) -r 8 "
110 | + "-p 125000 -w 50000) "
111 | + "-p 125000 -w 50000) "
112 | + "-p 125000 -w 50000) "
113 | + "-p 125000 -w 50000) "
114 | + "-p 125000 -w 50000) "
115 | + "-p 125000 -w 50000) "
116 | + "-p 125000 -w 50000",
117 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 1 -f 1 -m 0.2) -r 1 "
118 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 2 -f 1 -m 0.1) -r 2 "
119 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 3 -f 1 -m 0.05) -r 3 "
120 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 4 -f 1 -m 0.01) -r 4 "
121 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 5 -f 1 -m 0.01) -r 5 "
122 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 6 -f 1 -m 0.05) -r 6 "
123 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 7 -f 1 -m 0.1) -r 7 "
124 | + "-d (moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 8 -f 1 -m 0.2) -r 8 "
125 | + "-p 125000 -w 50000) "
126 | + "-p 125000 -w 50000) "
127 | + "-p 125000 -w 50000) "
128 | + "-p 125000 -w 50000) "
129 | + "-p 125000 -w 50000) "
130 | + "-p 125000 -w 50000) "
131 | + "-p 125000 -w 50000",
132 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 1 -a 10 -c 2 -m 0.2) -r 1 "
133 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 2 -a 10 -c 2 -m 0.1) -r 2 "
134 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 3 -a 10 -c 2 -m 0.05) -r 3 "
135 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 4 -a 10 -c 2 -m 0.01) -r 4 "
136 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 5 -a 10 -c 2 -m 0.01) -r 5 "
137 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 6 -a 10 -c 2 -m 0.05) -r 6 "
138 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.HyperplaneGenerator -i 7 -a 10 -c 2 -m 0.1) -r 7 "
139 | + "-d (moa.streams.generators.imbalanced.HyperplaneGenerator -i 8 -a 10 -c 2 -m 0.2) -r 8 "
140 | + "-p 125000 -w 50000) "
141 | + "-p 125000 -w 50000) "
142 | + "-p 125000 -w 50000) "
143 | + "-p 125000 -w 50000) "
144 | + "-p 125000 -w 50000) "
145 | + "-p 125000 -w 50000) "
146 | + "-p 125000 -w 50000",
147 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 1 -r 1 -a 10 -c 2 -m 0.2) -r 1 "
148 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 2 -r 2 -a 10 -c 2 -m 0.1) -r 2 "
149 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 3 -r 3 -a 10 -c 2 -m 0.05) -r 3 "
150 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 4 -r 4 -a 10 -c 2 -m 0.01) -r 4 "
151 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 5 -r 5 -a 10 -c 2 -m 0.01) -r 5 "
152 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 6 -r 6 -a 10 -c 2 -m 0.05) -r 6 "
153 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.RandomRBFGenerator -i 7 -r 7 -a 10 -c 2 -m 0.1) -r 7 "
154 | + "-d (moa.streams.generators.imbalanced.RandomRBFGenerator -i 8 -r 8 -a 10 -c 2 -m 0.2) -r 8 "
155 | + "-p 125000 -w 50000) "
156 | + "-p 125000 -w 50000) "
157 | + "-p 125000 -w 50000) "
158 | + "-p 125000 -w 50000) "
159 | + "-p 125000 -w 50000) "
160 | + "-p 125000 -w 50000) "
161 | + "-p 125000 -w 50000",
162 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 1 -f 1 -m 0.2) -r 1 "
163 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 2 -f 1 -m 0.1) -r 2 "
164 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 3 -f 1 -m 0.05) -r 3 "
165 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 4 -f 1 -m 0.01) -r 4 "
166 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 5 -f 1 -m 0.01) -r 5 "
167 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 6 -f 1 -m 0.05) -r 6 "
168 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SEAGenerator -i 7 -f 1 -m 0.1) -r 7 "
169 | + "-d (moa.streams.generators.imbalanced.SEAGenerator -i 8 -f 1 -m 0.2) -r 8 "
170 | + "-p 125000 -w 50000) "
171 | + "-p 125000 -w 50000) "
172 | + "-p 125000 -w 50000) "
173 | + "-p 125000 -w 50000) "
174 | + "-p 125000 -w 50000) "
175 | + "-p 125000 -w 50000) "
176 | + "-p 125000 -w 50000",
177 | "ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 1 -f 1 -m 0.2) -r 1 "
178 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 2 -f 1 -m 0.1) -r 2 "
179 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 3 -f 1 -m 0.05) -r 3 "
180 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 4 -f 1 -m 0.01) -r 4 "
181 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 5 -f 1 -m 0.01) -r 5 "
182 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 6 -f 1 -m 0.05) -r 6 "
183 | + "-d (ConceptDriftStream -s (moa.streams.generators.imbalanced.SineGenerator -i 7 -f 1 -m 0.1) -r 7 "
184 | + "-d (moa.streams.generators.imbalanced.SineGenerator -i 8 -f 1 -m 0.2) -r 8 "
185 | + "-p 125000 -w 50000) "
186 | + "-p 125000 -w 50000) "
187 | + "-p 125000 -w 50000) "
188 | + "-p 125000 -w 50000) "
189 | + "-p 125000 -w 50000) "
190 | + "-p 125000 -w 50000) "
191 | + "-p 125000 -w 50000",
192 | };
193 |
194 | String[] generatorsFilename = new String[] {
195 | "Sudden-AgrawalGenerator-IncreasingDecreasingIR",
196 | "Sudden-AssetNegotiationGenerator-IncreasingDecreasingIR",
197 | "Sudden-HyperplaneGenerator-IncreasingDecreasingIR",
198 | "Sudden-RandomRBFGenerator-IncreasingDecreasingIR",
199 | "Sudden-SEAGenerator-IncreasingDecreasingIR",
200 | "Sudden-SineGenerator-IncreasingDecreasingIR",
201 | "Gradual-AgrawalGenerator-IncreasingDecreasingIR",
202 | "Gradual-AssetNegotiationGenerator-IncreasingDecreasingIR",
203 | "Gradual-HyperplaneGenerator-IncreasingDecreasingIR",
204 | "Gradual-RandomRBFGenerator-IncreasingDecreasingIR",
205 | "Gradual-SEAGenerator-IncreasingDecreasingIR",
206 | "Gradual-SineGenerator-IncreasingDecreasingIR",
207 | };
208 |
209 | String[] algorithms = new String[] {
210 | "moa.classifiers.meta.imbalanced.ROSE",
211 | "moa.classifiers.meta.KUE",
212 | "moa.classifiers.meta.AccuracyWeightedEnsemble",
213 | "moa.classifiers.meta.AccuracyUpdatedEnsemble1",
214 | "moa.classifiers.meta.AccuracyUpdatedEnsemble2",
215 | "moa.classifiers.meta.DynamicWeightedMajority",
216 | "moa.classifiers.meta.SAE2",
217 | "moa.classifiers.meta.DACC",
218 | "moa.classifiers.meta.ADACC",
219 | "moa.classifiers.meta.AdaptiveRandomForest",
220 | "moa.classifiers.meta.ADOB",
221 | "moa.classifiers.meta.BOLE",
222 | "moa.classifiers.meta.GOOWE",
223 | "moa.classifiers.meta.HeterogeneousEnsembleBlast",
224 | "moa.classifiers.meta.LeveragingBag",
225 | "moa.classifiers.meta.OCBoost",
226 | "moa.classifiers.meta.OzaBag",
227 | "moa.classifiers.meta.OzaBagAdwin",
228 | "moa.classifiers.meta.OzaBagASHT",
229 | "moa.classifiers.meta.OzaBoost",
230 | "moa.classifiers.meta.OzaBoostAdwin",
231 | "moa.classifiers.meta.StreamingRandomPatches",
232 | "moa.classifiers.meta.UOB",
233 | "moa.classifiers.meta.OOB",
234 | "moa.classifiers.meta.imbalanced.OnlineSMOTEBagging",
235 | "moa.classifiers.meta.imbalanced.OnlineUnderOverBagging",
236 | "moa.classifiers.meta.imbalanced.CSMOTE",
237 | "moa.classifiers.meta.imbalanced.OnlineAdaBoost",
238 | "moa.classifiers.meta.imbalanced.OnlineAdaC2",
239 | "moa.classifiers.meta.imbalanced.OnlineRUSBoost",
240 | "moa.classifiers.meta.imbalanced.RebalanceStream",
241 | };
242 |
243 | String[] algorithmsFilename = new String[algorithms.length];
244 |
245 | for(int alg = 0; alg < algorithms.length; alg++) {
246 | algorithmsFilename[alg] = algorithms[alg].replaceAll("moa.classifiers.meta.", "").replaceAll("imbalanced.", "");
247 | }
248 |
249 | String classpathSeparator = SystemUtils.IS_OS_UNIX ? ":" : ";";
250 |
251 | for(int gen = 0; gen < generators.length; gen++)
252 | {
253 | for(int alg = 0; alg < algorithms.length; alg++)
254 | {
255 | System.out.println("java -Xms16g -Xmx1024g -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar" + classpathSeparator + "MOA-dependencies.jar "
256 | + "moa.DoTask EvaluateInterleavedTestThenTrain"
257 | + " -e \"(WindowAUCImbalancedPerformanceEvaluator)\""
258 | + " -s \"(" + generators[gen] + ")\""
259 | + " -l \"(" + algorithms[alg] + ")\""
260 | + " -f 500"
261 | + " -d results_drifting_IR/" + algorithmsFilename[alg] + "-" + generatorsFilename[gen] + ".csv");
262 | }
263 | }
264 | }
265 | }
--------------------------------------------------------------------------------
/src/main/java/experiments/Instance_Level_Difficulties.java:
--------------------------------------------------------------------------------
1 | package experiments;
2 |
3 | import org.apache.commons.lang3.SystemUtils;
4 |
/**
 * Prints, one per line, the shell commands that evaluate every algorithm on every
 * instance-level-difficulty dataset with MOA's {@code EvaluateInterleavedTestThenTrain} task.
 *
 * <p>The program only writes the commands to standard output; it does not launch them.
 * Datasets are read from {@code datasets-instance-level-difficulties/} and can be downloaded
 * from https://people.vcu.edu/~acano/ROSE/datasets-instance-level.zip
 */
public class Instance_Level_Difficulties {

    /**
     * Generates the experiment commands for all (dataset, algorithm) pairs, dataset-major.
     *
     * @param args ignored
     * @throws Exception never thrown by the current implementation; kept for interface compatibility
     */
    public static void main(String[] args) throws Exception {

        // Download datasets from https://people.vcu.edu/~acano/ROSE/datasets-instance-level.zip

        // Splitting into 5 clusters and evaluating the impact of percentage of borderline and rare instances

        String[] datasets = new String[] {
            // IR 1
            "Split5",
            "Split5+Rare20",
            "Split5+Rare40",
            "Split5+Rare60",
            "Split5+Rare80",
            "Split5+Rare100",
            "Split5+Borderline20",
            "Split5+Borderline40",
            "Split5+Borderline60",
            "Split5+Borderline80",
            "Split5+Borderline100",
            "Split5+Borderline20+Rare20",
            "Split5+Borderline40+Rare40",

            // IR 10
            "Split5+Im10",
            "Split5+Im10+Rare20",
            "Split5+Im10+Rare40",
            "Split5+Im10+Rare60",
            "Split5+Im10+Rare80",
            "Split5+Im10+Rare100",
            "Split5+Im10+Borderline20",
            "Split5+Im10+Borderline40",
            "Split5+Im10+Borderline60",
            "Split5+Im10+Borderline80",
            "Split5+Im10+Borderline100",
            "Split5+Im10+Borderline20+Rare20",
            "Split5+Im10+Borderline40+Rare40",

            // IR 100
            "Split5+Im1",
            "Split5+Im1+Rare20",
            "Split5+Im1+Rare40",
            "Split5+Im1+Rare60",
            "Split5+Im1+Rare80",
            "Split5+Im1+Rare100",
            "Split5+Im1+Borderline20",
            "Split5+Im1+Borderline40",
            "Split5+Im1+Borderline60",
            "Split5+Im1+Borderline80",
            "Split5+Im1+Borderline100",
            "Split5+Im1+Borderline20+Rare20",
            "Split5+Im1+Borderline40+Rare40",
        };

        String[] algorithms = new String[] {
            "moa.classifiers.meta.imbalanced.ROSE",
            "moa.classifiers.meta.KUE",
            "moa.classifiers.meta.AccuracyWeightedEnsemble",
            "moa.classifiers.meta.AccuracyUpdatedEnsemble1",
            "moa.classifiers.meta.AccuracyUpdatedEnsemble2",
            "moa.classifiers.meta.DynamicWeightedMajority",
            "moa.classifiers.meta.SAE2",
            "moa.classifiers.meta.DACC",
            "moa.classifiers.meta.ADACC",
            "moa.classifiers.meta.AdaptiveRandomForest",
            "moa.classifiers.meta.ADOB",
            "moa.classifiers.meta.BOLE",
            "moa.classifiers.meta.GOOWE",
            "moa.classifiers.meta.HeterogeneousEnsembleBlast",
            "moa.classifiers.meta.LeveragingBag",
            "moa.classifiers.meta.OCBoost",
            "moa.classifiers.meta.OzaBag",
            "moa.classifiers.meta.OzaBagAdwin",
            "moa.classifiers.meta.OzaBagASHT",
            "moa.classifiers.meta.OzaBoost",
            "moa.classifiers.meta.OzaBoostAdwin",
            "moa.classifiers.meta.StreamingRandomPatches",
            "moa.classifiers.meta.UOB",
            "moa.classifiers.meta.OOB",
            "moa.classifiers.meta.imbalanced.OnlineSMOTEBagging",
            "moa.classifiers.meta.imbalanced.OnlineUnderOverBagging",
            "moa.classifiers.meta.imbalanced.CSMOTE",
            "moa.classifiers.meta.imbalanced.OnlineAdaBoost",
            "moa.classifiers.meta.imbalanced.OnlineAdaC2",
            "moa.classifiers.meta.imbalanced.OnlineRUSBoost",
            "moa.classifiers.meta.imbalanced.RebalanceStream",
        };

        // Standard-library classpath separator of the platform running this generator
        // (":" on Unix-like systems, ";" on Windows).
        String classpathSeparator = java.io.File.pathSeparator;

        for (String dataset : datasets) {
            for (String algorithm : algorithms) {
                System.out.println("java -Xms16g -Xmx1024g -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar" + classpathSeparator + "MOA-dependencies.jar "
                        + "moa.DoTask EvaluateInterleavedTestThenTrain"
                        + " -e \"(WindowAUCImbalancedPerformanceEvaluator -w 100)\""
                        + " -s \"(ArffFileStream -f datasets-instance-level-difficulties/" + dataset + ".arff)\""
                        + " -l \"(" + algorithm + ")\""
                        + " -f 500"
                        + " -d results_instance_level_difficulties/" + algorithmFilename(algorithm) + "-" + dataset + ".csv");
            }
        }
    }

    /**
     * Shortens a fully qualified MOA classifier name to the token used in result file names.
     *
     * <p>Uses literal {@link String#replace(CharSequence, CharSequence)}: the package names
     * contain dots, which a regex-based {@code replaceAll} would treat as "any character".
     *
     * @param algorithm fully qualified classifier class name
     * @return the name without the {@code moa.classifiers.meta.} and {@code imbalanced.} parts
     */
    private static String algorithmFilename(String algorithm) {
        return algorithm.replace("moa.classifiers.meta.", "").replace("imbalanced.", "");
    }
}
--------------------------------------------------------------------------------
/src/main/java/experiments/Static_Imbalance_Ratio.java:
--------------------------------------------------------------------------------
1 | package experiments;
2 |
3 | import org.apache.commons.lang3.SystemUtils;
4 |
/**
 * Prints, one per line, the shell commands that evaluate every algorithm on every synthetic
 * generator at a fixed (static) imbalance ratio with MOA's
 * {@code EvaluateInterleavedTestThenTrain} task.
 *
 * <p>The program only writes the commands to standard output; it does not launch them.
 */
public class Static_Imbalance_Ratio {

    /**
     * Generates the experiment commands for all (generator, algorithm) pairs, generator-major.
     *
     * @param args ignored
     * @throws Exception never thrown by the current implementation; kept for interface compatibility
     */
    public static void main(String[] args) throws Exception {

        String[] generators = new String[] {
            // IR 100
            "moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.01",
            "moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 1 -f 3 -m 0.01",
            "moa.streams.generators.imbalanced.RandomRBFGenerator -i 1 -r 1 -a 50 -c 2 -m 0.01",
            "moa.streams.generators.imbalanced.SEAGenerator -i 1 -f 1 -m 0.01",
            "moa.streams.generators.imbalanced.SineGenerator -i 1 -f 1 -m 0.01",
            "moa.streams.generators.imbalanced.HyperplaneGenerator -i 1 -a 10 -c 2 -k 2 -t 0.05 -m 0.01",

            // IR 50
            "moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.02",
            "moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 1 -f 3 -m 0.02",
            "moa.streams.generators.imbalanced.RandomRBFGenerator -i 1 -r 1 -a 50 -c 2 -m 0.02",
            "moa.streams.generators.imbalanced.SEAGenerator -i 1 -f 1 -m 0.02",
            "moa.streams.generators.imbalanced.SineGenerator -i 1 -f 1 -m 0.02",
            "moa.streams.generators.imbalanced.HyperplaneGenerator -i 1 -a 10 -c 2 -k 2 -t 0.05 -m 0.02",

            // IR 20
            "moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.05",
            "moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 1 -f 3 -m 0.05",
            "moa.streams.generators.imbalanced.RandomRBFGenerator -i 1 -r 1 -a 50 -c 2 -m 0.05",
            "moa.streams.generators.imbalanced.SEAGenerator -i 1 -f 1 -m 0.05",
            "moa.streams.generators.imbalanced.SineGenerator -i 1 -f 1 -m 0.05",
            "moa.streams.generators.imbalanced.HyperplaneGenerator -i 1 -a 10 -c 2 -k 2 -t 0.05 -m 0.05",

            // IR 10
            "moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.1",
            "moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 1 -f 3 -m 0.1",
            "moa.streams.generators.imbalanced.RandomRBFGenerator -i 1 -r 1 -a 50 -c 2 -m 0.1",
            "moa.streams.generators.imbalanced.SEAGenerator -i 1 -f 1 -m 0.1",
            "moa.streams.generators.imbalanced.SineGenerator -i 1 -f 1 -m 0.1",
            "moa.streams.generators.imbalanced.HyperplaneGenerator -i 1 -a 10 -c 2 -k 2 -t 0.05 -m 0.1",

            // IR 5
            "moa.streams.generators.imbalanced.AgrawalGenerator -i 1 -f 2 -m 0.2",
            "moa.streams.generators.imbalanced.AssetNegotiationGenerator -i 1 -f 3 -m 0.2",
            "moa.streams.generators.imbalanced.RandomRBFGenerator -i 1 -r 1 -a 50 -c 2 -m 0.2",
            "moa.streams.generators.imbalanced.SEAGenerator -i 1 -f 1 -m 0.2",
            "moa.streams.generators.imbalanced.SineGenerator -i 1 -f 1 -m 0.2",
            "moa.streams.generators.imbalanced.HyperplaneGenerator -i 1 -a 10 -c 2 -k 2 -t 0.05 -m 0.2",
        };

        String[] algorithms = new String[] {
            "moa.classifiers.meta.imbalanced.ROSE",
            "moa.classifiers.meta.KUE",
            "moa.classifiers.meta.AccuracyWeightedEnsemble",
            "moa.classifiers.meta.AccuracyUpdatedEnsemble1",
            "moa.classifiers.meta.AccuracyUpdatedEnsemble2",
            "moa.classifiers.meta.DynamicWeightedMajority",
            "moa.classifiers.meta.SAE2",
            "moa.classifiers.meta.DACC",
            "moa.classifiers.meta.ADACC",
            "moa.classifiers.meta.AdaptiveRandomForest",
            "moa.classifiers.meta.ADOB",
            "moa.classifiers.meta.BOLE",
            "moa.classifiers.meta.GOOWE",
            "moa.classifiers.meta.HeterogeneousEnsembleBlast",
            "moa.classifiers.meta.LeveragingBag",
            "moa.classifiers.meta.OCBoost",
            "moa.classifiers.meta.OzaBag",
            "moa.classifiers.meta.OzaBagAdwin",
            "moa.classifiers.meta.OzaBagASHT",
            "moa.classifiers.meta.OzaBoost",
            "moa.classifiers.meta.OzaBoostAdwin",
            "moa.classifiers.meta.StreamingRandomPatches",
            "moa.classifiers.meta.UOB",
            "moa.classifiers.meta.OOB",
            "moa.classifiers.meta.imbalanced.OnlineSMOTEBagging",
            "moa.classifiers.meta.imbalanced.OnlineUnderOverBagging",
            "moa.classifiers.meta.imbalanced.CSMOTE",
            "moa.classifiers.meta.imbalanced.OnlineAdaBoost",
            "moa.classifiers.meta.imbalanced.OnlineAdaC2",
            "moa.classifiers.meta.imbalanced.OnlineRUSBoost",
            "moa.classifiers.meta.imbalanced.RebalanceStream",
        };

        // Standard-library classpath separator of the platform running this generator
        // (":" on Unix-like systems, ";" on Windows).
        String classpathSeparator = java.io.File.pathSeparator;

        for (String generator : generators) {
            String generatorName = generatorFilename(generator);

            for (String algorithm : algorithms) {
                System.out.println("java -Xms16g -Xmx1024g -javaagent:sizeofag-1.0.4.jar -cp ROSE-1.0.jar" + classpathSeparator + "MOA-dependencies.jar "
                        + "moa.DoTask EvaluateInterleavedTestThenTrain"
                        + " -e \"(WindowAUCImbalancedPerformanceEvaluator)\""
                        + " -s \"(" + generator + ")\""
                        + " -l \"(" + algorithm + ")\""
                        + " -f 500"
                        + " -d results_static_IR/" + algorithmFilename(algorithm) + "-" + generatorName + ".csv");
            }
        }
    }

    /**
     * Shortens a fully qualified MOA classifier name to the token used in result file names.
     *
     * <p>Uses literal {@link String#replace(CharSequence, CharSequence)}: the package names
     * contain dots, which a regex-based {@code replaceAll} would treat as "any character".
     *
     * @param algorithm fully qualified classifier class name
     * @return the name without the {@code moa.classifiers.meta.} and {@code imbalanced.} parts
     */
    private static String algorithmFilename(String algorithm) {
        return algorithm.replace("moa.classifiers.meta.", "").replace("imbalanced.", "");
    }

    /**
     * Turns a generator command line into a file-name token by dropping the generator package
     * prefix and every space (e.g. {@code AgrawalGenerator-i1-f2-m0.01}).
     *
     * @param generator generator class name followed by its options
     * @return the compact token used in result file names
     */
    private static String generatorFilename(String generator) {
        return generator.replace("moa.streams.generators.imbalanced.", "").replace(" ", "");
    }
}
--------------------------------------------------------------------------------
/src/main/java/moa/classifiers/meta/imbalanced/ROSE.java:
--------------------------------------------------------------------------------
1 | package moa.classifiers.meta.imbalanced;
2 |
3 | import java.util.ArrayList;
4 |
5 | import com.github.javacliparser.FloatOption;
6 | import com.github.javacliparser.IntOption;
7 | import com.yahoo.labs.samoa.instances.Instance;
8 | import com.yahoo.labs.samoa.instances.Instances;
9 | import com.yahoo.labs.samoa.instances.InstancesHeader;
10 |
11 | import moa.AbstractMOAObject;
12 | import moa.capabilities.CapabilitiesHandler;
13 | import moa.capabilities.Capability;
14 | import moa.capabilities.ImmutableCapabilities;
15 | import moa.classifiers.AbstractClassifier;
16 | import moa.classifiers.MultiClassClassifier;
17 | import moa.classifiers.core.driftdetection.ChangeDetector;
18 | import moa.classifiers.trees.RandomSubspaceHT;
19 | import moa.core.DoubleVector;
20 | import moa.core.InstanceExample;
21 | import moa.core.Measurement;
22 | import moa.core.MiscUtils;
23 | import moa.core.Utils;
24 | import moa.evaluation.WindowImbalancedClassificationPerformanceEvaluator;
25 | import moa.options.ClassOption;
26 |
27 | /**
28 | * ROSE: Robust Online Self-Adjusting Ensemble for Continual Learning on Imbalanced Drifting Data Streams
29 | *
30 | * @author Alberto Cano
31 | */
32 | public class ROSE extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler {
33 |
34 | private static final long serialVersionUID = 1L;
35 |
36 | public ClassOption treeLearnerOption = new ClassOption("treeLearner", 'l', "RandomSubspaceHT", RandomSubspaceHT.class, "RandomSubspaceHT");
37 |
38 | public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of trees.", 10, 1, Integer.MAX_VALUE);
39 |
40 | public FloatOption lambdaOption = new FloatOption("lambda", 'a', "The lambda parameter for bagging.", 6.0, 1.0, Float.MAX_VALUE);
41 |
42 | public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'x', "Change detector for drifts and its parameters", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-5");
43 |
44 | public ClassOption warningDetectionMethodOption = new ClassOption("warningDetectionMethod", 'p', "Change detector for warnings (start training bkg learner)", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-4");
45 |
46 | public IntOption featureSpaceOption = new IntOption("feature", 'f', "The feature space to employ {1: uniform distribution with at last 50% of features; 2: normal distribution based on percentageFeaturesMean (70% default)}.", 1, 1, 2);
47 |
48 | public FloatOption percentageFeaturesMean = new FloatOption("percentageFeaturesMean", 'm', "Mean for percentage of featues selected", 0.7, 0, 1);
49 |
50 | public FloatOption theta = new FloatOption("theta", 't', "The time decay factor for class size.", 0.99, 0, 1);
51 |
52 | public IntOption windowSizeOption = new IntOption("window", 'w', "The number of instances in the sliding window.", 500, 1, Integer.MAX_VALUE);
53 |
54 | protected ROSEBaseLearner[] ensemble;
55 | protected ROSEBaseLearner[] ensembleBackground;
56 | protected InstancesTimestap[] instancesClass;
57 |
58 | protected double classSize[];
59 | protected long instancesSeen;
60 | protected long firstWarningOn;
61 | protected boolean warningDetected;
62 | protected WindowImbalancedClassificationPerformanceEvaluator evaluator;
63 |
64 | @Override
65 | public void resetLearningImpl() {
66 | this.warningDetected = true;
67 | this.firstWarningOn = 0;
68 | this.instancesClass = null;
69 | this.ensemble = null;
70 | this.ensembleBackground = null;
71 | this.classSize = null;
72 | this.instancesSeen = 0;
73 | this.evaluator = new WindowImbalancedClassificationPerformanceEvaluator();
74 | }
75 |
76 | @Override
77 | public void trainOnInstanceImpl(Instance instance) {
78 | this.instancesSeen++;
79 |
80 | this.instancesClass[(int) instance.classValue()].add(instance, instancesSeen);
81 |
82 | if(this.ensemble == null) {
83 | initEnsemble(instance);
84 | }
85 |
86 | for (int i=0; i < classSize.length; i++) {
87 | classSize[i] = theta.getValue() * classSize[i] + (1d - theta.getValue()) * ((int) instance.classValue() == i ? 1d:0d);
88 | }
89 |
90 | double lambda = lambdaOption.getValue() + lambdaOption.getValue() * Math.log(classSize[Utils.maxIndex(classSize)] / classSize[(int) instance.classValue()]);
91 |
92 | for (int i = 0; i < this.ensemble.length; i++) {
93 | DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(instance));
94 | InstanceExample example = new InstanceExample(instance);
95 | this.ensemble[i].evaluator.addResult(example, vote.getArrayRef());
96 |
97 | int k = MiscUtils.poisson(lambda, this.classifierRandom);
98 | if (k > 0) {
99 | this.ensemble[i].trainOnInstance(instance, k, this.instancesSeen);
100 | }
101 | }
102 |
103 | if(!this.warningDetected) {
104 | for (int i = 0; i < this.ensemble.length; i++) {
105 | if(this.ensemble[i].warningDetected) {
106 | this.warningDetected = true;
107 | this.firstWarningOn = instancesSeen;
108 | break;
109 | }
110 | }
111 |
112 | if(this.warningDetected) {
113 | int numberAttributes = instance.numAttributes() - 1;
114 |
115 | RandomSubspaceHT treeLearner = (RandomSubspaceHT) getPreparedClassOption(this.treeLearnerOption);
116 | treeLearner.resetLearning();
117 |
118 | WindowImbalancedClassificationPerformanceEvaluator classificationEvaluator = new WindowImbalancedClassificationPerformanceEvaluator();
119 |
120 | for(int i = 0; i < this.ensembleSizeOption.getValue(); i++) {
121 |
122 | int subspaceSize = -1;
123 |
124 | if(featureSpaceOption.getValue() == 1) {
125 | subspaceSize = 1 + (int) Math.floor(numberAttributes/2) + this.classifierRandom.nextInt((int)Math.ceil(numberAttributes/2));
126 | } else if(featureSpaceOption.getValue() == 2) {
127 | subspaceSize = (int) Math.round(this.percentageFeaturesMean.getValue() * numberAttributes + ((1.0 - this.percentageFeaturesMean.getValue()) * numberAttributes) * this.classifierRandom.nextGaussian() * 0.5);
128 | }
129 |
130 | if (subspaceSize > numberAttributes) {
131 | subspaceSize = numberAttributes;
132 | } else if (subspaceSize <= 0) {
133 | subspaceSize = 1;
134 | }
135 |
136 | treeLearner.subspaceSizeOption.setValue(subspaceSize);
137 | treeLearner.setRandomSeed(this.classifierRandom.nextInt(Integer.MAX_VALUE));
138 |
139 | this.ensembleBackground[i] = new ROSEBaseLearner(
140 | (RandomSubspaceHT) treeLearner.copy(),
141 | (WindowImbalancedClassificationPerformanceEvaluator) classificationEvaluator.copy(),
142 | this.instancesSeen,
143 | driftDetectionMethodOption,
144 | warningDetectionMethodOption,
145 | false);
146 | }
147 |
148 | int[] indexClass = new int[instance.numClasses()];
149 | Long[] oldestTimestamps = new Long[instance.numClasses()];
150 |
151 | do {
152 | Long oldestTimestamp = Long.MAX_VALUE;
153 | int nextClass = -1;
154 |
155 | for(int c = 0; c < instance.numClasses(); c++) {
156 | oldestTimestamps[c] = this.instancesClass[c].getTimestamp(indexClass[c]);
157 |
158 | if(oldestTimestamps[c] != null && oldestTimestamps[c] < oldestTimestamp) {
159 | oldestTimestamp = oldestTimestamps[c];
160 | nextClass = c;
161 | }
162 | }
163 |
164 | if(nextClass == -1) {
165 | break;
166 | }
167 |
168 | Instance windowInstance = this.instancesClass[nextClass].getInstance(indexClass[nextClass]);
169 |
170 | for (int i = 0; i < this.ensembleBackground.length; i++) {
171 | int k = MiscUtils.poisson(lambdaOption.getValue(), this.classifierRandom);
172 | if (k > 0) {
173 | this.ensembleBackground[i].trainOnInstance(windowInstance, k, this.instancesSeen);
174 | }
175 | }
176 |
177 | indexClass[nextClass]++;
178 |
179 | } while(true);
180 | }
181 | }
182 |
183 | if(this.warningDetected) {
184 | for (int i = 0; i < this.ensembleBackground.length; i++) {
185 | DoubleVector vote = new DoubleVector(this.ensembleBackground[i].getVotesForInstance(instance));
186 | InstanceExample example = new InstanceExample(instance);
187 | this.ensembleBackground[i].evaluator.addResult(example, vote.getArrayRef());
188 |
189 | int k = MiscUtils.poisson(lambda, this.classifierRandom);
190 | if (k > 0) {
191 | this.ensembleBackground[i].trainOnInstance(instance, k, this.instancesSeen);
192 | }
193 | }
194 |
195 | if(this.instancesSeen - this.firstWarningOn == this.evaluator.widthOption.getValue()) {
196 | // Compare the ensemble and the background ensemble. Select the best components
197 |
198 | ArrayList classifiers = new ArrayList();
199 | ArrayList selection = new ArrayList();
200 | ArrayList kappas = new ArrayList();
201 | ArrayList accuracies = new ArrayList();
202 |
203 | for (int i = 0; i < this.ensemble.length; i++) {
204 | classifiers.add(this.ensemble[i]);
205 | kappas.add(this.ensemble[i].evaluator.getKappa());
206 | accuracies.add(this.ensemble[i].evaluator.getAccuracy());
207 | }
208 |
209 | for (int i = 0; i < this.ensembleBackground.length; i++) {
210 | classifiers.add(this.ensembleBackground[i]);
211 | kappas.add(this.ensembleBackground[i].evaluator.getKappa());
212 | accuracies.add(this.ensembleBackground[i].evaluator.getAccuracy());
213 | }
214 |
215 | for (int i = 0; i < this.ensemble.length; i++) {
216 |
217 | double maxKappaAccuracy = -1;
218 | int maxKappaAccuracyClassifier = -1;
219 |
220 | for (int j = 0; j < this.ensemble.length + this.ensembleBackground.length - i; j++) {
221 | if(kappas.get(j) * accuracies.get(j) >= maxKappaAccuracy) {
222 | maxKappaAccuracy = kappas.get(j) * accuracies.get(j);
223 | maxKappaAccuracyClassifier = j;
224 | }
225 | }
226 |
227 | selection.add(classifiers.get(maxKappaAccuracyClassifier));
228 |
229 | classifiers.remove(maxKappaAccuracyClassifier);
230 | kappas.remove(maxKappaAccuracyClassifier);
231 | accuracies.remove(maxKappaAccuracyClassifier);
232 | }
233 |
234 | for (int i = 0; i < this.ensemble.length; i++) {
235 | this.ensemble[i] = selection.get(i);
236 | }
237 |
238 | for (int i = 0; i < this.ensembleBackground.length; i++) {
239 | this.ensembleBackground[i] = null;
240 | }
241 |
242 | this.warningDetected = false;
243 | }
244 | }
245 | }
246 |
247 | @Override
248 | public double[] getVotesForInstance(Instance instance) {
249 | if(this.ensemble == null) {
250 | initEnsemble(instance);
251 | }
252 |
253 | DoubleVector combinedVote = new DoubleVector();
254 | DoubleVector combinedVoteUnweighted = new DoubleVector();
255 |
256 | for(int i = 0; i < this.ensemble.length; i++) {
257 | DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(instance));
258 | if (vote.sumOfValues() > 0.0) {
259 | vote.normalize();
260 |
261 | combinedVoteUnweighted.addValues(vote);
262 |
263 | double kappa = this.ensemble[i].evaluator.getKappa();
264 | double accuracy = this.ensemble[i].evaluator.getAccuracy();
265 |
266 | if(kappa > 0.0) {
267 | vote.scaleValues(kappa * accuracy);
268 | combinedVote.addValues(vote);
269 | }
270 | }
271 | }
272 |
273 | if(combinedVote.sumOfValues() == 0) {
274 | return combinedVoteUnweighted.getArrayRef();
275 | } else {
276 | return combinedVote.getArrayRef();
277 | }
278 | }
279 |
280 | @Override
281 | public boolean isRandomizable() {
282 | return true;
283 | }
284 |
285 | @Override
286 | public void getModelDescription(StringBuilder arg0, int arg1) {
287 | }
288 |
289 | @Override
290 | protected Measurement[] getModelMeasurementsImpl() {
291 | return null;
292 | }
293 |
294 | protected void initEnsemble(Instance instance) {
295 |
296 | this.instancesClass = new InstancesTimestap[instance.numClasses()];
297 |
298 | for(int i = 0; i < instance.numClasses(); i++) {
299 | this.instancesClass[i] = new InstancesTimestap(this.getModelContext());
300 | }
301 |
302 | classSize = new double[instance.numClasses()];
303 |
304 | for (int i=0; i < classSize.length; i++) {
305 | classSize[i] = 1d / classSize.length;
306 | }
307 |
308 | int ensembleSize = this.ensembleSizeOption.getValue();
309 | int numberAttributes = instance.numAttributes() - 1;
310 |
311 | this.ensemble = new ROSEBaseLearner[ensembleSize];
312 | this.ensembleBackground = new ROSEBaseLearner[ensembleSize];
313 |
314 | WindowImbalancedClassificationPerformanceEvaluator classificationEvaluator = new WindowImbalancedClassificationPerformanceEvaluator();
315 | RandomSubspaceHT treeLearner = (RandomSubspaceHT) getPreparedClassOption(this.treeLearnerOption);
316 | treeLearner.resetLearning();
317 |
318 | // Primary ensemble
319 | for(int i = 0; i < ensembleSize; i++) {
320 | int subspaceSize = -1;
321 |
322 | if(featureSpaceOption.getValue() == 1) {
323 | subspaceSize = 1 + (int) Math.floor(numberAttributes/2) + this.classifierRandom.nextInt((int)Math.ceil(numberAttributes/2));
324 | } else if(featureSpaceOption.getValue() == 2) {
325 | subspaceSize = (int) Math.round(this.percentageFeaturesMean.getValue() * numberAttributes + ((1.0 - this.percentageFeaturesMean.getValue()) * numberAttributes) * this.classifierRandom.nextGaussian() * 0.5);
326 | }
327 |
328 | if (subspaceSize > numberAttributes) {
329 | subspaceSize = numberAttributes;
330 | } else if (subspaceSize <= 0) {
331 | subspaceSize = 1;
332 | }
333 |
334 | treeLearner.subspaceSizeOption.setValue(subspaceSize);
335 | treeLearner.setRandomSeed(this.classifierRandom.nextInt(Integer.MAX_VALUE));
336 |
337 | this.ensemble[i] = new ROSEBaseLearner(
338 | (RandomSubspaceHT) treeLearner.copy(),
339 | (WindowImbalancedClassificationPerformanceEvaluator) classificationEvaluator.copy(),
340 | this.instancesSeen,
341 | driftDetectionMethodOption,
342 | warningDetectionMethodOption,
343 | false);
344 | }
345 |
346 | // Background ensemble
347 | for(int i = 0; i < this.ensembleSizeOption.getValue(); i++) {
348 | int subspaceSize = -1;
349 |
350 | if(featureSpaceOption.getValue() == 1) {
351 | subspaceSize = 1 + (int) Math.floor(numberAttributes/2) + this.classifierRandom.nextInt((int)Math.ceil(numberAttributes/2));
352 | } else if(featureSpaceOption.getValue() == 2) {
353 | subspaceSize = (int) Math.round(this.percentageFeaturesMean.getValue() * numberAttributes + ((1.0 - this.percentageFeaturesMean.getValue()) * numberAttributes) * this.classifierRandom.nextGaussian() * 0.5);
354 | }
355 |
356 | if (subspaceSize > numberAttributes) {
357 | subspaceSize = numberAttributes;
358 | } else if (subspaceSize <= 0) {
359 | subspaceSize = 1;
360 | }
361 |
362 | treeLearner.subspaceSizeOption.setValue(subspaceSize);
363 | treeLearner.setRandomSeed(this.classifierRandom.nextInt(Integer.MAX_VALUE));
364 |
365 | this.ensembleBackground[i] = new ROSEBaseLearner(
(RandomSubspaceHT) treeLearner.copy(),
366 | (WindowImbalancedClassificationPerformanceEvaluator) classificationEvaluator.copy(),
367 | this.instancesSeen,
368 | driftDetectionMethodOption,
369 | warningDetectionMethodOption,
370 | false);
371 | }
372 | }
373 |
374 | @Override
375 | public ImmutableCapabilities defineImmutableCapabilities() {
376 | if (this.getClass() == ROSE.class)
377 | return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
378 | else
379 | return new ImmutableCapabilities(Capability.VIEW_STANDARD);
380 | }
381 |
/**
 * Returns the one-line, human-readable description of this classifier.
 *
 * @return the fixed purpose string
 */
382 | @Override
383 | public String getPurposeString() {
384 | return "ROSE: Robust Online Self-Adjusting Ensemble for Continual Learning on Imbalanced Drifting Data Streams";
385 | }
386 |
387 | /**
388 | * Inner class that represents a single tree member of the forest.
389 | * It contains some analysis information, such as the numberOfDriftsDetected,
390 | */
391 | protected final class ROSEBaseLearner extends AbstractMOAObject {
392 | private static final long serialVersionUID = -3758478930527651262L;
393 | public long createdOn;
394 | public boolean warningDetected;
395 | public long lastDriftOn;
396 | public long lastWarningOn;
397 | public RandomSubspaceHT classifier;
398 | public boolean isBackgroundLearner;
399 |
400 | // The drift and warning object parameters.
401 | protected ClassOption driftOption;
402 | protected ClassOption warningOption;
403 |
404 | // Drift and warning detection
405 | protected ChangeDetector driftDetectionMethod;
406 | protected ChangeDetector warningDetectionMethod;
407 |
408 | // Bkg learner
409 | protected ROSEBaseLearner bkgLearner;
410 | // Statistics
411 | public WindowImbalancedClassificationPerformanceEvaluator evaluator;
412 | protected int numberOfDriftsDetected;
413 | protected int numberOfWarningsDetected;
414 |
415 | public ROSEBaseLearner (RandomSubspaceHT instantiatedClassifier, WindowImbalancedClassificationPerformanceEvaluator evaluatorInstantiated, long instancesSeen, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) {
416 | this.createdOn = instancesSeen;
417 | this.lastDriftOn = 0;
418 | this.lastWarningOn = 0;
419 | this.warningDetected = false;
420 |
421 | this.classifier = instantiatedClassifier;
422 | this.evaluator = evaluatorInstantiated;
423 |
424 | this.numberOfDriftsDetected = 0;
425 | this.numberOfWarningsDetected = 0;
426 | this.isBackgroundLearner = isBackgroundLearner;
427 |
428 | this.driftOption = driftOption;
429 | this.driftDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.driftOption)).copy();
430 |
431 | this.warningOption = warningOption;
432 | this.warningDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.warningOption)).copy();
433 | }
434 |
435 | public void reset() {
436 | if(this.bkgLearner != null) {
437 | this.classifier = this.bkgLearner.classifier;
438 |
439 | this.driftDetectionMethod = this.bkgLearner.driftDetectionMethod;
440 | this.warningDetectionMethod = this.bkgLearner.warningDetectionMethod;
441 |
442 | this.evaluator = this.bkgLearner.evaluator;
443 | this.createdOn = this.bkgLearner.createdOn;
444 | this.bkgLearner = null;
445 | }
446 | else {
447 | this.classifier.resetLearning();
448 | this.createdOn = instancesSeen;
449 | this.driftDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.driftOption)).copy();
450 | }
451 |
452 | this.lastWarningOn = 0;
453 | this.lastDriftOn = 0;
454 | this.warningDetected = false;
455 | this.evaluator.reset();
456 | }
457 |
458 | public void trainOnInstance(Instance instance, double weight, long instancesSeen) {
459 | Instance weightedInstance = (Instance) instance.copy();
460 | weightedInstance.setWeight(instance.weight() * weight);
461 | this.classifier.trainOnInstance(weightedInstance);
462 |
463 | if(this.bkgLearner != null)
464 | this.bkgLearner.classifier.trainOnInstance(weightedInstance);
465 |
466 | // Should it use a drift detector? Also, is it a backgroundLearner? If so, then do not "incept" another one.
467 | if(!this.isBackgroundLearner) {
468 | boolean correctlyClassifies = this.classifier.correctlyClassifies(instance);
469 | // Check for warning only if useBkgLearner is active
470 | // Update the warning detection method
471 | this.warningDetectionMethod.input(correctlyClassifies ? 0 : 1);
472 | // Check if there was a change
473 | if(this.warningDetectionMethod.getChange()) {
474 | this.warningDetected = true;
475 | this.lastWarningOn = instancesSeen;
476 | this.numberOfWarningsDetected++;
477 | // Create a new bkgTree classifier
478 | RandomSubspaceHT bkgClassifier = (RandomSubspaceHT) this.classifier.copy();
479 | bkgClassifier.resetLearning();
480 | bkgClassifier.setup(this.classifier.listAttributes, this.classifier.instanceHeader);
481 |
482 | // Resets the evaluator
483 | WindowImbalancedClassificationPerformanceEvaluator bkgEvaluator = (WindowImbalancedClassificationPerformanceEvaluator) this.evaluator.copy();
484 | bkgEvaluator.reset();
485 |
486 | // Create a new bkgLearner object
487 | this.bkgLearner = new ROSEBaseLearner(bkgClassifier, bkgEvaluator, instancesSeen, this.driftOption, this.warningOption, true);
488 |
489 | // Update the warning detection object for the current object
490 | // (this effectively resets changes made to the object while it was still a bkg learner).
491 | this.warningDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.warningOption)).copy();
492 | } else {
493 | this.warningDetected = false;
494 | }
495 |
496 | /*********** drift detection ***********/
497 |
498 | // Update the DRIFT detection method
499 | this.driftDetectionMethod.input(correctlyClassifies ? 0 : 1);
500 | // Check if there was a change
501 | if(this.driftDetectionMethod.getChange()) {
502 | this.lastDriftOn = instancesSeen;
503 | this.numberOfDriftsDetected++;
504 | this.reset();
505 | }
506 | }
507 | }
508 |
509 | public double[] getVotesForInstance(Instance instance) {
510 | DoubleVector vote = new DoubleVector(this.classifier.getVotesForInstance(instance));
511 | return vote.getArrayRef();
512 | }
513 |
514 | @Override
515 | public void getDescription(StringBuilder sb, int indent) {
516 | }
517 | }
518 |
519 | private class InstancesTimestap {
520 | private Instances instances;
521 | private ArrayList timestamps;
522 |
523 | public InstancesTimestap(InstancesHeader header) {
524 | instances = new Instances(header);
525 | timestamps = new ArrayList();
526 | }
527 |
528 | public void add(Instance instance, long timestamp) {
529 | instances.add(instance);
530 | timestamps.add(timestamp);
531 |
532 | if (instances.size() > windowSizeOption.getValue() / instance.numClasses()) {
533 | instances.delete(0);
534 | timestamps.remove(0);
535 | }
536 | }
537 |
538 | public Instance getInstance(int index) {
539 | if(index >= instances.size())
540 | return null;
541 | else
542 | return instances.get(index);
543 | }
544 |
545 | public Long getTimestamp(int index) {
546 | if(index >= timestamps.size())
547 | return null;
548 | else
549 | return timestamps.get(index);
550 | }
551 | }
552 | }
--------------------------------------------------------------------------------
/src/main/java/moa/classifiers/trees/RandomSubspaceHT.java:
--------------------------------------------------------------------------------
1 | package moa.classifiers.trees;
2 |
3 | import java.util.ArrayList;
4 |
5 | import com.github.javacliparser.IntOption;
6 | import com.yahoo.labs.samoa.instances.Instance;
7 | import com.yahoo.labs.samoa.instances.Instances;
8 |
9 | /**
10 | * Hoeffding Tree on a fixed random subspace of features
11 | *
12 | * @author Alberto Cano
13 | */
14 | public class RandomSubspaceHT extends HoeffdingTree {
15 |
16 | private static final long serialVersionUID = 1L;
17 |
18 | public IntOption subspaceSizeOption = new IntOption("subspaceFeaturesSize", 'k', "Number of features", 1, 1, Integer.MAX_VALUE);
19 |
20 | public boolean[] listAttributes;
21 |
22 | public Instances instanceHeader;
23 |
24 | public void setup(boolean[] listAttibutes, Instances instanceHeader) {
25 | this.listAttributes = listAttibutes;
26 | this.instanceHeader = instanceHeader;
27 | }
28 |
29 | @Override
30 | public String getPurposeString() {
31 | return "ROSE Hoeffding Tree for data streams";
32 | }
33 |
34 | @Override
35 | public void resetLearningImpl() {
36 | super.resetLearningImpl();
37 | }
38 |
39 | @Override
40 | public void trainOnInstanceImpl(Instance instance) {
41 | if(this.listAttributes == null) {
42 | setupListAttributes(instance);
43 | }
44 |
45 | Instance instanceProjected = instance.copy();
46 |
47 | for(int att = instanceProjected.numAttributes()-2; att >= 0; att--) {
48 | if(this.listAttributes[att] == false) {
49 | instanceProjected.deleteAttributeAt(att);
50 | }
51 | }
52 |
53 | instanceProjected.setDataset(instanceHeader);
54 |
55 | super.trainOnInstanceImpl(instanceProjected);
56 | }
57 |
58 | @Override
59 | public double[] getVotesForInstance(Instance instance) {
60 | if(this.listAttributes == null) {
61 | setupListAttributes(instance);
62 | }
63 |
64 | Instance instanceProjected = instance.copy();
65 |
66 | for(int att = instanceProjected.numAttributes()-2; att >= 0; att--) {
67 | if(this.listAttributes[att] == false) {
68 | instanceProjected.deleteAttributeAt(att);
69 | }
70 | }
71 |
72 | instanceProjected.setDataset(instanceHeader);
73 |
74 | return super.getVotesForInstance(instanceProjected);
75 | }
76 |
77 | @Override
78 | public boolean isRandomizable() {
79 | return true;
80 | }
81 |
82 | protected void setupListAttributes(Instance instance) {
83 | int numberAttributes = instance.numAttributes() - 1;
84 | int subspaceSize = subspaceSizeOption.getValue();
85 |
86 | this.listAttributes = new boolean[numberAttributes];
87 | this.instanceHeader = new Instances(instance.dataset());
88 |
89 | ArrayList attributesPool = new ArrayList();
90 |
91 | for(int i = 0; i < numberAttributes; i++) {
92 | attributesPool.add(i);
93 | }
94 |
95 | for(int i = 0; i < subspaceSize; i++) {
96 | this.listAttributes[attributesPool.remove(this.classifierRandom.nextInt(attributesPool.size()))] = true;
97 | }
98 |
99 | for(int att = numberAttributes-1; att >= 0; att--) {
100 | if(this.listAttributes[att] == false) {
101 | this.instanceHeader.deleteAttributeAt(att);
102 | }
103 | }
104 | }
105 | }
--------------------------------------------------------------------------------
/src/main/java/moa/evaluation/WindowAUCImbalancedPerformanceEvaluator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * ImbalancedPerformanceEvaluator.java
3 | * Copyright (C) 2016 Poznan University of Technology
4 | * @author Dariusz Brzezinski (dbrzezinski@cs.put.poznan.pl)
5 | * @author Tomasz Pewinski
6 | *
7 | * Licensed under the Apache License, Version 2.0 (the "License");
8 | * you may not use this file except in compliance with the License.
9 | * You may obtain a copy of the License at
10 | *
11 | * http://www.apache.org/licenses/LICENSE-2.0
12 | *
13 | * Unless required by applicable law or agreed to in writing, software
14 | * distributed under the License is distributed on an "AS IS" BASIS,
15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | * See the License for the specific language governing permissions and
17 | * limitations under the License.
18 | */
19 | package moa.evaluation;
20 |
21 | import java.util.TreeSet;
22 |
23 | import moa.core.Example;
24 | import moa.core.Measurement;
25 | import moa.core.ObjectRepository;
26 | import moa.core.Utils;
27 | import moa.options.AbstractOptionHandler;
28 |
29 | import com.github.javacliparser.IntOption;
30 | import com.yahoo.labs.samoa.instances.Instance;
31 | import com.yahoo.labs.samoa.instances.InstanceImpl;
32 | import com.yahoo.labs.samoa.instances.Prediction;
33 |
34 | import moa.tasks.TaskMonitor;
35 |
36 | /**
37 | * Classification evaluator that updates evaluation results using a sliding
38 | * window. Performance measures designed for class imbalance problems.
39 | * Only to be used for binary classification problems with unweighted instances.
40 | * Class 0 - majority/negative examples, class 1 - minority, positive examples.
41 | * Prequential AUC calculation as described and analyzed in: D. Brzezinski,
42 | * J. Stefanowski, "Prequential AUC: Properties of the Area Under the ROC
43 | * Curve for Data Streams with Concept Drift", Knowledge and Information
44 | * Systems, 2017.
45 | *
46 | * @author Dariusz Brzezinski (dbrzezinski at cs.put.poznan.pl)
47 | * @author Tomasz Pewinski
48 | */
49 | public class WindowAUCImbalancedPerformanceEvaluator extends
50 | AbstractOptionHandler implements ClassificationPerformanceEvaluator {
51 |
52 | private static final long serialVersionUID = 1L;
53 |
54 | public IntOption widthOption = new IntOption("width", 'w',
55 | "Size of Window", 1000);
56 |
57 | protected double totalObservedInstances = 0;
58 | private Estimator aucEstimator;
59 | private SimpleEstimator weightMajorityClassifier;
60 | protected int numClasses;
61 |
/**
 * Running arithmetic mean: accumulates a sum and a count of added values.
 */
public class SimpleEstimator {
    /** Number of values added so far. */
    protected double len;

    /** Sum of all values added so far. */
    protected double sum;

    /**
     * Adds one observation.
     *
     * @param value value to accumulate
     */
    public void add(double value) {
        this.len += 1;
        this.sum = this.sum + value;
    }

    /**
     * @return {@code sum / len}; this is NaN while nothing has been added
     */
    public double estimation() {
        final double mean = this.sum / this.len;
        return mean;
    }
}
76 |
77 | public class Estimator {
78 |
79 | public class Score implements Comparable {
80 | /**
81 | * Predicted score of the example
82 | */
83 | protected double value;
84 |
85 | /**
86 | * Age of example - position in the window where the example was
87 | * added
88 | */
89 | protected int posWindow;
90 |
91 | /**
92 | * True if example's true label is positive
93 | */
94 | protected boolean isPositive;
95 |
96 | /**
97 | * Constructor.
98 | *
99 | * @param value
100 | * score value
101 | * @param position
102 | * score position in window (defines its age)
103 | * @param isPositive
104 | * true if the example's true label is positive
105 | */
106 | public Score(double value, int position, boolean isPositive) {
107 | this.value = value;
108 | this.posWindow = position;
109 | this.isPositive = isPositive;
110 | }
111 |
112 | /**
113 | * Sort descending based on score value.
114 | */
115 | @Override
116 | public int compareTo(Score o) {
117 | if (o.value < this.value) {
118 | return -1;
119 | } else if (o.value > this.value){
120 | return 1;
121 | } else {
122 | if (!o.isPositive && this.isPositive) {
123 | return -1;
124 | } else if (o.isPositive && !this.isPositive){
125 | return 1;
126 | } else {
127 | if (o.posWindow > this.posWindow) {
128 | return -1;
129 | } else if (o.posWindow < this.posWindow){
130 | return 1;
131 | } else {
132 | return 0;
133 | }
134 | }
135 | }
136 | }
137 |
138 | @Override
139 | public boolean equals(Object o) {
140 | return (o instanceof Score) && ((Score)o).posWindow == this.posWindow;
141 | }
142 | }
143 |
144 | protected TreeSet sortedScores;
145 |
146 | protected TreeSet holdoutSortedScores;
147 |
148 | protected Score[] window;
149 |
150 | protected double[] predictions;
151 |
152 | protected int posWindow;
153 |
154 | protected int size;
155 |
156 | protected double numPos;
157 |
158 | protected double numNeg;
159 |
160 | protected double holdoutNumPos;
161 |
162 | protected double holdoutNumNeg;
163 |
164 | protected double correctPredictions;
165 |
166 | protected double correctPositivePredictions;
167 |
168 | protected double[] columnKappa;
169 |
170 | protected double[] rowKappa;
171 |
172 | protected int confusionMatrix[][];
173 |
174 | public Estimator(int sizeWindow) {
175 | this.sortedScores = new TreeSet();
176 | this.holdoutSortedScores = new TreeSet();
177 | this.size = sizeWindow;
178 | this.window = new Score[sizeWindow];
179 | this.predictions = new double[sizeWindow];
180 |
181 | this.rowKappa = new double[numClasses];
182 | this.columnKappa = new double[numClasses];
183 | this.confusionMatrix = new int[numClasses][numClasses];
184 |
185 | for (int i = 0; i < numClasses; i++) {
186 | this.rowKappa[i] = 0.0;
187 | this.columnKappa[i] = 0.0;
188 | }
189 |
190 | this.posWindow = 0;
191 | this.numPos = 0;
192 | this.numNeg = 0;
193 | this.holdoutNumPos = 0;
194 | this.holdoutNumNeg = 0;
195 | this.correctPredictions = 0;
196 | this.correctPositivePredictions = 0;
197 | }
198 |
199 | public void add(double score, boolean isPositive, boolean correctPrediction) {
200 | // // periodically update holdout evaluation
201 | if (size > 0 && posWindow % this.size == 0) {
202 | this.holdoutSortedScores = new TreeSet();
203 |
204 | for (Score s : this.sortedScores) {
205 | this.holdoutSortedScores.add(s);
206 | }
207 |
208 | this.holdoutNumPos = this.numPos;
209 | this.holdoutNumNeg = this.numNeg;
210 | }
211 |
212 | // // if the window is used and it's full
213 | if (size > 0 && posWindow >= this.size) {
214 | // // remove the oldest example
215 | sortedScores.remove(window[posWindow % size]);
216 | correctPredictions -= predictions[posWindow % size];
217 | correctPositivePredictions -= window[posWindow % size].isPositive ? predictions[posWindow % size] : 0;
218 |
219 | if (window[posWindow % size].isPositive) {
220 | numPos--;
221 | } else {
222 | numNeg--;
223 | }
224 |
225 | int oldestExampleTrueClass = window[posWindow % size].isPositive ? 1 : 0;
226 | int oldestExamplePredictedClass = predictions[posWindow % size] == 1.0 ? oldestExampleTrueClass : Math.abs(oldestExampleTrueClass - 1);
227 |
228 | this.rowKappa[oldestExamplePredictedClass] -= 1;
229 | this.columnKappa[oldestExampleTrueClass] -= 1;
230 | this.confusionMatrix[oldestExamplePredictedClass][oldestExampleTrueClass]--;
231 | }
232 |
233 | // // add new example
234 | Score newScore = new Score(score, posWindow, isPositive);
235 | sortedScores.add(newScore);
236 | correctPredictions += correctPrediction ? 1 : 0;
237 | correctPositivePredictions += correctPrediction && isPositive ? 1 : 0;
238 |
239 | int trueClass = isPositive ? 1 : 0;
240 | int predictedClass = correctPrediction ? trueClass : Math.abs(trueClass - 1);
241 | this.rowKappa[predictedClass] += 1;
242 | this.columnKappa[trueClass] += 1;
243 | this.confusionMatrix[predictedClass][trueClass]++;
244 |
245 |
246 | if (newScore.isPositive) {
247 | numPos++;
248 | } else {
249 | numNeg++;
250 | }
251 |
252 | if (size > 0) {
253 | window[posWindow % size] = newScore;
254 | predictions[posWindow % size] = correctPrediction ? 1 : 0;
255 | }
256 |
257 | //// posWindow needs to be always incremented to differentiate between examples in the red-black tree
258 | posWindow++;
259 | }
260 |
261 | public double getAUC() {
262 | double AUC = 0;
263 | double c = 0;
264 | double prevc = 0;
265 | double lastPosScore = Double.MAX_VALUE;
266 |
267 | if (numPos == 0 || numNeg == 0) {
268 | return 1;
269 | }
270 |
271 | for (Score s : sortedScores){
272 | if(s.isPositive) {
273 | if (s.value != lastPosScore) {
274 | prevc = c;
275 | lastPosScore = s.value;
276 | }
277 |
278 | c += 1;
279 | } else {
280 | if (s.value == lastPosScore) {
281 | // tie
282 | AUC += ((double)(c + prevc))/2.0;
283 | } else {
284 | AUC += c;
285 | }
286 | }
287 | }
288 |
289 | return AUC / (numPos * numNeg);
290 | }
291 |
292 | public double getHoldoutAUC() {
293 | double AUC = 0;
294 | double c = 0;
295 | double prevc = 0;
296 | double lastPosScore = Double.MAX_VALUE;
297 |
298 | if (holdoutSortedScores.isEmpty()) {
299 | return 0;
300 | }
301 |
302 | if (holdoutNumPos == 0 || holdoutNumNeg == 0) {
303 | return 1;
304 | }
305 |
306 | for (Score s : holdoutSortedScores){
307 | if(s.isPositive) {
308 | if (s.value != lastPosScore) {
309 | prevc = c;
310 | lastPosScore = s.value;
311 | }
312 |
313 | c += 1;
314 | } else {
315 | if (s.value == lastPosScore) {
316 | // tie
317 | AUC += ((double)(c + prevc))/2.0;
318 | } else {
319 | AUC += c;
320 | }
321 | }
322 | }
323 |
324 | return AUC / (holdoutNumPos * holdoutNumNeg);
325 | }
326 |
327 | public double getScoredAUC() {
328 | double AOC = 0;
329 | double AUC = 0;
330 | double r = 0;
331 | double prevr = 0;
332 | double c = 0;
333 | double prevc = 0;
334 | double R_plus, R_minus;
335 | double lastPosScore = Double.MAX_VALUE;
336 | double lastNegScore = Double.MAX_VALUE;
337 |
338 | if (numPos == 0 || numNeg == 0) {
339 | return 1;
340 | }
341 |
342 | for (Score s : sortedScores){
343 | if(s.isPositive) {
344 | if (s.value != lastPosScore) {
345 | prevc = c;
346 | lastPosScore = s.value;
347 | }
348 |
349 | c += s.value;
350 |
351 | if (s.value == lastNegScore) {
352 | // tie
353 | AOC += ((double)(r + prevr))/2.0;
354 | } else {
355 | AOC += r;
356 | }
357 | } else {
358 | if (s.value != lastNegScore) {
359 | prevr = r;
360 | lastNegScore = s.value;
361 | }
362 |
363 | r += s.value;
364 |
365 | if (s.value == lastPosScore) {
366 | // tie
367 | AUC += ((double)(c + prevc))/2.0;
368 | } else {
369 | AUC += c;
370 | }
371 | }
372 | }
373 |
374 | R_minus = (numPos*r - AOC)/(numPos * numNeg);
375 | R_plus = (AUC)/(numPos * numNeg);
376 | return R_plus - R_minus;
377 | }
378 |
379 | public double getRatio() {
380 | if(numNeg == 0) {
381 | return Double.MAX_VALUE;
382 | } else {
383 | return numPos/numNeg;
384 | }
385 | }
386 |
387 | public double getAccuracy() {
388 | if (size > 0) {
389 | return totalObservedInstances > 0.0 ? correctPredictions / Math.min(size, totalObservedInstances) : 0.0;
390 | } else {
391 | return totalObservedInstances > 0.0 ? correctPredictions / totalObservedInstances : 0.0;
392 | }
393 | }
394 |
395 | public double getKappa() {
396 | double p0 = getAccuracy();
397 | double pc = 0.0;
398 |
399 | if (size > 0) {
400 | for (int i = 0; i < numClasses; i++) {
401 | pc += (this.rowKappa[i]/Math.min(size, totalObservedInstances)) * (this.columnKappa[i]/Math.min(size, totalObservedInstances));
402 | }
403 | } else {
404 | for (int i = 0; i < numClasses; i++) {
405 | pc += (this.rowKappa[i]/totalObservedInstances) * (this.columnKappa[i]/totalObservedInstances);
406 | }
407 | }
408 | return (p0 - pc) / (1.0 - pc);
409 | }
410 |
411 | private double getKappaM() {
412 | double p0 = getAccuracy();
413 | double pc = weightMajorityClassifier.estimation();
414 |
415 | return (p0 - pc) / (1.0 - pc);
416 | }
417 |
418 | public double getGMean() {
419 | double positiveAccuracy = correctPositivePredictions / numPos;
420 | double negativeAccuracy = (correctPredictions - correctPositivePredictions) / numNeg;
421 | return Math.sqrt(positiveAccuracy * negativeAccuracy);
422 | }
423 |
424 | public double getRecall() {
425 | return correctPositivePredictions / numPos;
426 | }
427 | }
428 |
/**
 * Resets the evaluator for the currently stored number of classes by
 * delegating to {@link #reset(int)}.
 *
 * @throws RuntimeException if the stored class count is not 2, which is
 *         also the case while it still holds its initial value of 0
 */
429 | @Override
430 | public void reset() {
431 | reset(this.numClasses);
432 | }
433 |
434 | public void reset(int numClasses) {
435 | if (numClasses != 2) {
436 | throw new RuntimeException(
437 | "Too many classes ("
438 | + numClasses
439 | + "). AUC evaluation can be performed only for two-class problems!");
440 | }
441 |
442 | this.numClasses = numClasses;
443 |
444 | this.aucEstimator = new Estimator(this.widthOption.getValue());
445 | this.weightMajorityClassifier = new SimpleEstimator();
446 | this.totalObservedInstances = 0;
447 | }
448 |
449 | @Override
450 | public void addResult(Example exampleInstance, double[] classVotes) {
451 | InstanceImpl inst = (InstanceImpl)exampleInstance.getData();
452 | double weight = inst.weight();
453 |
454 | if (inst.classIsMissing() == false){
455 | int trueClass = (int) inst.classValue();
456 |
457 | if (weight > 0.0) {
458 | // // initialize evaluator
459 | if (totalObservedInstances == 0) {
460 | reset(inst.dataset().numClasses());
461 | }
462 | this.totalObservedInstances += 1;
463 |
464 | //// if classVotes has length == 1, then the negative (0) class got all the votes
465 | Double normalizedVote = 0.0;
466 |
467 | //// normalize and add score
468 | if(classVotes.length == 2) {
469 | normalizedVote = classVotes[1]/(classVotes[0] + classVotes[1]);
470 | }
471 |
472 | if(normalizedVote.isNaN()){
473 | normalizedVote = 0.0;
474 | }
475 |
476 | this.aucEstimator.add(normalizedVote, trueClass == 1, Utils.maxIndex(classVotes) == trueClass);
477 | this.weightMajorityClassifier.add((this.aucEstimator.getRatio() <= 1 ? 0 : 1) == trueClass ? weight: 0);
478 | }
479 | }
480 | }
481 |
482 | protected double AUC(int[][] confusionMatrix, int positiveClass, int negativeClass)
483 | {
484 | int tp = confusionMatrix[positiveClass][positiveClass];
485 | int fp = confusionMatrix[positiveClass][negativeClass];
486 | int tn = confusionMatrix[negativeClass][negativeClass];
487 | int fn = confusionMatrix[negativeClass][positiveClass];
488 |
489 | double tpRate = 1.0, fpRate = 0.0;
490 |
491 | if(tp + fn != 0)
492 | tpRate = tp / (double) (tp + fn);
493 |
494 | if(fp + tn != 0)
495 | fpRate = fp / (double) (fp + tn);
496 |
497 | double auc = (1.0 + tpRate - fpRate) / 2.0;
498 |
499 | return auc;
500 | }
501 |
502 | @Override
503 | public Measurement[] getPerformanceMeasurements() {
504 |
505 | Measurement[] measurement = new Measurement[11 + numClasses*numClasses];
506 |
507 | measurement[0] = new Measurement("classified instances", this.totalObservedInstances);
508 | measurement[1] = new Measurement("AUC", AUC(this.aucEstimator.confusionMatrix, 1, 0));
509 | measurement[1] = new Measurement("pAUC", this.aucEstimator.getAUC());
510 | measurement[2] = new Measurement("sAUC", this.aucEstimator.getScoredAUC());
511 | measurement[3] = new Measurement("Accuracy", this.aucEstimator.getAccuracy());
512 | measurement[4] = new Measurement("Kappa", this.aucEstimator.getKappa());
513 | measurement[5] = new Measurement("Periodical holdout AUC", this.aucEstimator.getHoldoutAUC());
514 | measurement[6] = new Measurement("Pos/Neg ratio", this.aucEstimator.getRatio());
515 | measurement[7] = new Measurement("G-Mean", this.aucEstimator.getGMean());
516 | measurement[8] = new Measurement("Recall", this.aucEstimator.getRecall());
517 | measurement[9] = new Measurement("KappaM", this.aucEstimator.getKappaM());
518 | measurement[10] = new Measurement("Classes", numClasses);
519 |
520 | for(int i = 0; i < numClasses; i++) {
521 | for(int j = 0; j < numClasses; j++) {
522 | measurement[11 + i*numClasses + j] = new Measurement("CM["+i+"]["+j+"]", this.aucEstimator.confusionMatrix[i][j]);
523 | }
524 | }
525 |
526 | return measurement;
527 | }
528 |
/**
 * Appends a description of the current performance measurements to the
 * builder, using {@code Measurement.getMeasurementsDescription}.
 *
 * @param sb     builder receiving the description
 * @param indent indentation level passed through to the formatter
 */
529 | @Override
530 | public void getDescription(StringBuilder sb, int indent) {
531 | Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
532 | sb, indent);
533 | }
534 |
/**
 * Performs no preparation: the body is empty and both arguments are ignored.
 *
 * @param monitor    task monitor (unused)
 * @param repository object repository (unused)
 */
535 | @Override
536 | public void prepareForUseImpl(TaskMonitor monitor,
537 | ObjectRepository repository) {
538 | }
539 |
/**
 * Gives direct access to the internal window estimator.
 *
 * @return the current estimator; {@code null} until {@code reset(int)} has
 *         created one
 */
540 | public Estimator getAucEstimator() {
541 | return aucEstimator;
542 | }
543 |
/**
 * Unsupported variant: this evaluator needs raw class votes, so this
 * overload always fails.
 *
 * @param arg0 ignored
 * @param arg1 ignored
 * @throws RuntimeException always
 */
544 | @Override
545 | public void addResult(Example arg0, Prediction arg1) {
546 | throw new RuntimeException("Designed for scoring classifiers");
547 | }
548 | }
549 |
--------------------------------------------------------------------------------
/src/main/java/moa/evaluation/WindowAUCMultiClassImbalancedPerformanceEvaluator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * WindowClassificationPerformanceEvaluator.java
3 | * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
4 | * @author Albert Bifet (abifet@cs.waikato.ac.nz)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 | package moa.evaluation;
21 |
22 | import java.io.Serializable;
23 | import java.util.ArrayList;
24 | import java.util.Iterator;
25 | import java.util.List;
26 | import java.util.TreeSet;
27 |
28 | import moa.core.Example;
29 | import moa.core.Measurement;
30 | import moa.core.ObjectRepository;
31 | import moa.options.AbstractOptionHandler;
32 | import moa.tasks.TaskMonitor;
33 |
34 | import com.github.javacliparser.IntOption;
35 | import com.yahoo.labs.samoa.instances.Instance;
36 | import com.yahoo.labs.samoa.instances.Prediction;
37 |
38 | import weka.core.Utils;
39 |
40 |
41 | /**
42 | * Classification evaluator that updates evaluation results using a sliding
43 | * window. Only to be used for binary classification problems with unweighted instances.
44 | *
45 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
46 | * @version $Revision: 7 $
47 | */
48 | public class WindowAUCMultiClassImbalancedPerformanceEvaluator extends AbstractOptionHandler implements ClassificationPerformanceEvaluator, Serializable {
49 |
50 | private static final long serialVersionUID = 1L;
51 |
52 | public IntOption widthOption = new IntOption("width", 'w', "Size of Window", 500);
53 |
54 | protected double totalObservedInstances = 0;
55 |
56 | protected Estimator aucEstimator;
57 |
58 | protected int numClasses;
59 |
60 | public class Estimator implements Serializable {
61 |
/**
 * One scored example, ordered by the score it received for the positive
 * class of the tree it is stored in.
 */
public class Score implements Comparable<Score>, Serializable {
    /**
     * Predicted scores of the example, indexed by class
     */
    protected double[] value;

    /**
     * true class label index of the example
     */
    protected int realClass;

    /**
     * Age of example - position in the window where the example was
     * added
     */
    protected int posWindow;

    /**
     * positive class index for this tree
     */
    protected int pos_class;

    /**
     * Constructor.
     *
     * @param trueClass
     *            index of the example's true class
     * @param value
     *            score values, indexed by class (stored by reference)
     * @param position
     *            score position in window (defines its age)
     * @param pos_class
     *            index of the class used as the sort key
     */
    public Score(int trueClass, double[] value, int position, int pos_class) {
        this.realClass = trueClass;
        this.value = value;
        this.posWindow = position;
        this.pos_class = pos_class;
    }

    /**
     * Sort descending based on score value of the positive class of the current tree.
     * Ties are broken by age, older examples first. Both scores are read at
     * this object's {@code pos_class}.
     */
    @Override
    public int compareTo(Score o) {
        double mine = this.value[pos_class];
        double theirs = o.value[pos_class];

        if (theirs < mine) {
            return -1;
        }
        if (theirs > mine) {
            return 1;
        }
        if (o.posWindow > this.posWindow) {
            return -1;
        }
        if (o.posWindow < this.posWindow) {
            return 1;
        }
        return 0;
    }

    /** Two scores are equal when they were added at the same position. */
    @Override
    public boolean equals(Object o) {
        return (o instanceof Score) && ((Score) o).posWindow == this.posWindow;
    }

    /** Consistent with {@link #equals(Object)}: derived from the position only. */
    @Override
    public int hashCode() {
        return this.posWindow;
    }
}
124 |
125 | public class RB_tree{
126 | protected int pos_class;
127 | protected int neg_class;
128 | protected double numPos;
129 | protected double numNeg;
130 | protected TreeSet sortedScores;
131 |
132 | public RB_tree(int pos_class_idx, int neg_class_idx) {
133 | pos_class = pos_class_idx;
134 | neg_class = neg_class_idx;
135 | numPos = 0;
136 | numNeg = 0;
137 | sortedScores = new TreeSet();
138 | }
139 | }
140 |
141 | protected List rb_trees;//the set of red-black trees for any 2 classes (numClasses*(numClasses-1))
142 |
143 | protected Score[] window;//store current window of examples
144 |
145 | protected double[] predictions;//store correct/incorrect prediction results for all examples in the window: if the example is correctly classified or not: 1-correct, 0-incorrect.
146 |
147 | protected int[] predictionsClass;
148 |
149 | protected int posWindow;//the position where the oldest example that needs to be removed (if window is full) and where the new example that needs to be added.
150 |
151 | protected int size;//window size
152 |
153 | protected double correctPredictions;//the number of correctly classified examples in the current window
154 |
155 | protected double[] correctPrediction_perclass; //the number of correctly classified examples in the current window, for each class (for calculating recall and g-mean)
156 |
157 | protected double[] totalObservedInstances_perclass_window;//the number of examples in the current window, for each class (for calculating recall and g-mean)
158 |
159 | protected double[] inst_classvotes;//the normalized votes for the current data sample
160 |
161 | protected int inst_trueClass;//the true class of the current data sample
162 |
163 | protected WindowAUCImbalancedPerformanceEvaluator[] auc_2c;//store 2-class version of AUC for each class being the positive class (for calculating Provost's weighted AUC)
164 |
165 | protected WindowAUCImbalancedPerformanceEvaluator[] ewauc_2c;//store 2-class version of AUC for each class being the positive class (for calculating Provost's weighted AUC but with equal weight)
166 |
167 | protected double[] columnKappa;
168 | protected double[] rowKappa;
169 | protected int confusionMatrix[][];
170 |
171 | public Estimator(int sizeWindow) {
172 |
173 | this.rb_trees = new ArrayList();
174 | for(int i = 0; i < numClasses-1; i++) {
175 | for(int j = i+1; j < numClasses; j++) {
176 | RB_tree tree_i = new RB_tree(i,j);//a redblack tree with class i as the positive class
177 | RB_tree tree_j = new RB_tree(j,i);//a redblack tree with class j as the positive class
178 | this.rb_trees.add(tree_i);
179 | this.rb_trees.add(tree_j);
180 | }
181 | }
182 |
183 | this.size = sizeWindow;
184 | this.window = new Score[sizeWindow];
185 | this.predictions = new double[sizeWindow];
186 | this.predictionsClass = new int[sizeWindow];
187 | this.correctPrediction_perclass = new double[numClasses];
188 | this.totalObservedInstances_perclass_window = new double[numClasses];
189 | this.auc_2c = new WindowAUCImbalancedPerformanceEvaluator[numClasses];
190 | this.ewauc_2c = new WindowAUCImbalancedPerformanceEvaluator[numClasses];
191 |
192 | this.rowKappa = new double[numClasses];
193 | this.columnKappa = new double[numClasses];
194 | this.confusionMatrix = new int[numClasses][numClasses];
195 |
196 | this.posWindow = 0;
197 | this.correctPredictions = 0;
198 | for(int i = 0; i < numClasses; i++) {
199 | correctPrediction_perclass[i] = 0;
200 | totalObservedInstances_perclass_window[i] = 0;
201 |
202 | auc_2c[i] = new WindowAUCImbalancedPerformanceEvaluator();
203 | auc_2c[i].widthOption.setValue(size);
204 | auc_2c[i].reset(2);
205 |
206 |
207 | ewauc_2c[i] = new WindowAUCImbalancedPerformanceEvaluator();
208 | ewauc_2c[i].widthOption.setValue(size);
209 | ewauc_2c[i].reset(2);
210 | }
211 |
212 | }
213 |
214 | public void add(double[] score, int trueClass, boolean correctPrediction, int predictedClass) {
215 |
216 | this.inst_classvotes = score;
217 | this.inst_trueClass = trueClass;
218 |
219 | int[] tree_idx_add = this.find_trees(trueClass);
220 | // if the window is used and it's full
221 | if (size > 0 && posWindow >= this.size) {
222 | int idx_remove = window[posWindow % size].realClass;//the class label of the example to be removed from the window
223 | int[] tree_idx_remove = this.find_trees(idx_remove);//find the indices of trees containing the class label of the example to be removed from the window
224 | correctPredictions -= predictions[posWindow % size];
225 | correctPrediction_perclass[idx_remove] -= predictions[posWindow % size];
226 | totalObservedInstances_perclass_window[idx_remove] -= 1;
227 | for(int i = 0; i < tree_idx_remove.length; i++) {
228 | RB_tree tree_i = this.rb_trees.get(tree_idx_remove[i]);
229 | // remove the oldest example from the tree with "trueClass"
230 | Score node = this.find_node(tree_i, posWindow);
231 | tree_i.sortedScores.remove(node);
232 |
233 | if (window[posWindow % size].realClass == tree_i.pos_class) {
234 | tree_i.numPos--;
235 | } else {
236 | tree_i.numNeg--;
237 | }
238 | }
239 |
240 | this.rowKappa[predictionsClass[posWindow % size]] -= 1;
241 | this.columnKappa[idx_remove] -= 1;
242 | this.confusionMatrix[predictionsClass[posWindow % size]][idx_remove]--;
243 | }
244 |
245 | // add new example
246 | Score newScore = new Score(trueClass, score, posWindow, -1);//new score for updating windows
247 | correctPredictions += correctPrediction ? 1 : 0;
248 | correctPrediction_perclass[trueClass] += correctPrediction ? 1 : 0;
249 | totalObservedInstances_perclass_window[trueClass] += 1;
250 | this.rowKappa[predictedClass] += 1;
251 | this.columnKappa[trueClass] += 1;
252 | this.confusionMatrix[predictedClass][trueClass]++;
253 |
254 | if (size > 0) {
255 | window[posWindow % size] = newScore;
256 | predictions[posWindow % size] = correctPrediction ? 1 : 0;
257 | predictionsClass[posWindow % size] = predictedClass;
258 | }
259 | for(int i = 0; i < tree_idx_add.length; i++) {
260 | RB_tree tree_i = this.rb_trees.get(tree_idx_add[i]);
261 | Score newScore_tree = new Score(trueClass, score, posWindow, tree_i.pos_class);//new score for updating trees
262 | tree_i.sortedScores.add(newScore_tree);
263 |
264 | if (trueClass == tree_i.pos_class) {
265 | tree_i.numPos++;
266 | } else {
267 | tree_i.numNeg++;
268 | }
269 | }
270 | // posWindow needs to be always incremented to differentiate between examples in the red-black tree
271 | posWindow++;
272 | }
273 |
274 | public Score find_node(RB_tree tree_i, int posWindow) {
275 | Score node;
276 | Iterator iterator = tree_i.sortedScores.iterator();
277 | while(iterator.hasNext()) {
278 | node = iterator.next();
279 | if((node.posWindow % size) == (posWindow % size))
280 | return node;
281 | }
282 | return null;
283 | }
284 |
285 | //return the tree indices of those involving class_idx
286 | public int[] find_trees(int class_idx) {
287 | int[] tree_idx = new int[2*(numClasses-1)];
288 | int t = 0;
289 | for(RB_tree tree_i: this.rb_trees) {
290 | if((tree_i.pos_class == class_idx) || (tree_i.neg_class == class_idx)) {
291 | tree_idx[t] = this.rb_trees.indexOf(tree_i);
292 | t++;
293 | }
294 | }
295 | return tree_idx;
296 | }
297 |
298 | //return the tree with given positive class index and negative class index
299 | public int find_onetree(int pos_class, int neg_class) {
300 | int idx=-1;
301 | for(RB_tree tree_i: this.rb_trees) {
302 | if((tree_i.pos_class == pos_class) && (tree_i.neg_class == neg_class)) {
303 | idx = this.rb_trees.indexOf(tree_i);
304 | break;
305 | }
306 | }
307 | return idx;
308 | }
309 |
310 | public double getPMAUC() {
311 | double pmAUC = 0;
312 | double c = 0;
313 | for(RB_tree tree_i: this.rb_trees) {
314 | c = c + this.getAUC(tree_i);
315 | }
316 | pmAUC = c/(numClasses*(numClasses-1));
317 | return pmAUC;
318 | }
319 |
320 | //calculate AUC for the given tree
321 | public double getAUC(RB_tree current_tree) {
322 | double AUC = 0;
323 | double c = 0;
324 | if (current_tree.numPos == 0 || current_tree.numNeg == 0) {
325 | return 0;
326 | }
327 |
328 | for (Score s : current_tree.sortedScores){
329 | if(s.realClass==current_tree.pos_class) {
330 | c += 1;
331 | } else {
332 | AUC += c;
333 | }
334 | }
335 |
336 | return AUC / (current_tree.numPos * current_tree.numNeg);
337 | }
338 |
339 | //calculate AUC based on the given positive class index and negative class index.
340 | public double getAUC(int pos_class, int neg_class) {
341 | double AUC = 0;
342 | double c = 0;
343 | int idx = this.find_onetree(pos_class, neg_class);
344 | RB_tree current_tree = this.rb_trees.get(idx);
345 |
346 | if (current_tree.numPos == 0 || current_tree.numNeg == 0) {
347 | return 0;
348 | }
349 |
350 | for (Score s : current_tree.sortedScores){
351 | if(s.realClass==pos_class) {
352 | c += 1;
353 | } else {
354 | AUC += c;
355 | }
356 | }
357 |
358 | return AUC / (current_tree.numPos * current_tree.numNeg);
359 | }
360 |
361 |
362 | /**
 * Provost and Domingos' weighted AUC [2003, "Tree Induction for Probability-Based Ranking"]:
 * compute the expected AUC, which is the weighted average of the AUCs obtained taking each
 * class as the reference class in turn (i.e., making it class 0 and all other
 * classes class 1). The weight of a class's AUC is the class's frequency in the data.
367 | * */
368 | public double getWeightedAUC() {
369 | double wAUC = 0.0;
370 | double[] class_weights = new double[numClasses];
371 |
372 | // get class weights
373 | for(int c = 0; c < numClasses; c++) {
374 | if(totalObservedInstances>0 && totalObservedInstances 0) {
448 | return totalObservedInstances > 0.0 ? correctPredictions / Math.min(size, totalObservedInstances) : 0.0;
449 | } else {
450 | return totalObservedInstances > 0.0 ? correctPredictions / totalObservedInstances : 0.0;
451 | }
452 | }
453 |
454 | public double getKappa() {
455 | double p0 = getAccuracy();
456 | double pc = 0.0;
457 |
458 | if (size > 0) {
459 | for (int i = 0; i < numClasses; i++) {
460 | pc += (this.rowKappa[i]/Math.min(size, totalObservedInstances)) * (this.columnKappa[i]/Math.min(size, totalObservedInstances));
461 | }
462 | } else {
463 | for (int i = 0; i < numClasses; i++) {
464 | pc += (this.rowKappa[i]/totalObservedInstances) * (this.columnKappa[i]/totalObservedInstances);
465 | }
466 | }
467 | return (p0 - pc) / (1.0 - pc);
468 | }
469 |
470 | //Recall of Class i in the current window
471 | public double getRecall(int class_idx) {
472 | return totalObservedInstances_perclass_window[class_idx] > 0.0 ? correctPrediction_perclass[class_idx] / totalObservedInstances_perclass_window[class_idx] : 0.0;
473 | }
474 |
475 | //G-mean in the current window
476 | public double getGmean() {
477 | double gmean = 1.0;
478 | for(int i = 0; i < numClasses; i++) {
479 | gmean = gmean * this.getRecall(i);
480 | }
481 | gmean = Math.pow(gmean, (double)1/numClasses);
482 | return gmean;
483 | }
484 |
485 | }
486 |
487 |
488 | public void reset(int numClasses) {
489 | this.numClasses = numClasses;
490 |
491 | this.aucEstimator = new Estimator(this.widthOption.getValue());
492 | this.totalObservedInstances = 0;
493 | }
494 |
495 | public void addResult(Instance inst, double[] classVotes) {
496 |
497 | double weight = inst.weight();
498 | int trueClass = (int) inst.classValue();
499 |
500 | if (weight > 0.0) {
501 | // // initialize evaluator
502 | if (totalObservedInstances == 0) {
503 | reset(inst.dataset().numClasses());
504 | }
505 | this.totalObservedInstances += 1;
506 |
507 | if(classVotes.length != numClasses) {
508 | classVotes = new double[this.numClasses];
509 | for(int i = 0; i < numClasses; i++)
510 | classVotes[i] = 1.0 / (double) numClasses;
511 | }
512 |
513 | double[] normalizedVote = classVotes.clone();//get a deep copy of classVotes
514 |
515 | for(int i = 0; i < normalizedVote.length; i++) {
516 | if(Double.isNaN(normalizedVote[i]))
517 | normalizedVote[i] = 0.0;
518 | if(Double.isInfinite(normalizedVote[i]))
519 | normalizedVote[i] = Double.MAX_VALUE;
520 | }
521 |
522 | if(Utils.sum(normalizedVote) == 0) { normalizedVote[0] = 1; }
523 |
524 | //// normalize and add score
525 | Utils.normalize(normalizedVote);
526 |
527 | this.aucEstimator.add(normalizedVote, trueClass, Utils.maxIndex(classVotes) == trueClass, Utils.maxIndex(classVotes));
528 | }
529 | }
530 |
531 | public Measurement[] getPerformanceMeasurements() {
532 | Measurement[] measurement = new Measurement[8 + numClasses*numClasses];
533 |
534 | measurement[0] = new Measurement("classified instances", this.totalObservedInstances);
535 | measurement[1] = new Measurement("PMAUC", this.aucEstimator.getPMAUC());
536 | measurement[2] = new Measurement("WMAUC", this.aucEstimator.getWeightedAUC());
537 | measurement[3] = new Measurement("EWMAUC", this.aucEstimator.getEqualWeightedAUC());
538 | measurement[4] = new Measurement("Accuracy", this.aucEstimator.getAccuracy());
539 | measurement[5] = new Measurement("Kappa", this.aucEstimator.getKappa());
540 | measurement[6] = new Measurement("G-Mean", this.aucEstimator.getGmean());
541 | measurement[7] = new Measurement("Classes", numClasses);
542 |
543 | for(int i = 0; i < numClasses; i++) {
544 | for(int j = 0; j < numClasses; j++) {
545 | measurement[8 + i*numClasses + j] = new Measurement("CM["+i+"]["+j+"]", this.aucEstimator.confusionMatrix[i][j]);
546 | }
547 | }
548 |
549 | return measurement;
550 | }
551 |
552 | @Override
553 | public void getDescription(StringBuilder sb, int indent) {
554 | Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
555 | sb, indent);
556 | }
557 |
558 | @Override
559 | public void prepareForUseImpl(TaskMonitor monitor,
560 | ObjectRepository repository) {
561 | }
562 |
563 | public Estimator getAucEstimator() {
564 | return aucEstimator;
565 | }
566 |
567 |
568 | @Override
569 | public void reset() {
570 | reset(this.numClasses);
571 | }
572 |
573 |
574 | @Override
575 | public void addResult(Example example, double[] classVotes) {
576 | addResult(example.getData(), classVotes);
577 | }
578 |
579 | @Override
580 | public void addResult(Example testInst, Prediction prediction) {
581 | throw new RuntimeException("Designed for scoring classifiers");
582 | }
583 | }
--------------------------------------------------------------------------------
/src/main/java/moa/evaluation/WindowImbalancedClassificationPerformanceEvaluator.java:
--------------------------------------------------------------------------------
1 | package moa.evaluation;
2 |
3 | import moa.core.Example;
4 | import moa.core.Measurement;
5 | import moa.core.ObjectRepository;
6 | import moa.core.Utils;
7 | import moa.options.AbstractOptionHandler;
8 |
9 | import com.github.javacliparser.IntOption;
10 | import com.yahoo.labs.samoa.instances.Instance;
11 | import com.yahoo.labs.samoa.instances.InstanceImpl;
12 | import com.yahoo.labs.samoa.instances.Prediction;
13 |
14 | import moa.tasks.TaskMonitor;
15 |
16 | /**
17 | * Classification evaluator that updates evaluation results using a sliding window.
18 | * Performance measures designed for class imbalance problems (binary and multiclass).
19 | *
20 | * @author Alberto Cano
21 | */
22 | public class WindowImbalancedClassificationPerformanceEvaluator extends AbstractOptionHandler implements ClassificationPerformanceEvaluator {
23 |
24 | private static final long serialVersionUID = 1L;
25 |
26 | public IntOption widthOption = new IntOption("width", 'w', "Size of Window", 1000);
27 |
28 | protected int confusionMatrix[][];
29 |
30 | protected int numClasses;
31 |
32 | protected int totalObservedInstances;
33 |
34 | protected int positionWindow;
35 |
36 | protected int predictionsTrue[];
37 |
38 | protected int predictionsPredicted[];
39 |
40 | @Override
41 | public void reset() {
42 | reset(this.numClasses);
43 | }
44 |
45 | public void reset(int numClasses) {
46 | this.numClasses = numClasses;
47 | this.positionWindow = 0;
48 | this.totalObservedInstances = 0;
49 | this.confusionMatrix = new int[numClasses][numClasses];
50 | this.predictionsTrue = new int[this.widthOption.getValue()];
51 | this.predictionsPredicted = new int[this.widthOption.getValue()];
52 | }
53 |
54 | @Override
55 | public void addResult(Example exampleInstance, double[] classVotes) {
56 | InstanceImpl inst = (InstanceImpl)exampleInstance.getData();
57 | double weight = inst.weight();
58 |
59 | if (inst.classIsMissing() == false){
60 | int trueClass = (int) inst.classValue();
61 | int predictedClass = Utils.maxIndex(classVotes);
62 |
63 | if (weight > 0.0) {
64 | // // initialize evaluator
65 | if (totalObservedInstances == 0) {
66 | reset(inst.dataset().numClasses());
67 | }
68 |
69 | this.totalObservedInstances++;
70 |
71 | if(totalObservedInstances > this.widthOption.getValue())
72 | {
73 | this.confusionMatrix[predictionsPredicted[positionWindow]][predictionsTrue[positionWindow]]--;
74 | }
75 |
76 | this.predictionsTrue[positionWindow] = trueClass;
77 | this.predictionsPredicted[positionWindow] = predictedClass;
78 | this.confusionMatrix[predictedClass][trueClass]++;
79 |
80 | positionWindow = (positionWindow + 1) % this.widthOption.getValue();
81 | }
82 | }
83 | }
84 |
85 | protected double Kappa(int[][] confusionMatrix)
86 | {
87 | int correctedClassified = 0;
88 | int numberInstancesTotal = 0;
89 | int[] numberInstances = new int[numClasses];
90 | int[] predictedInstances = new int[numClasses];
91 |
92 | for(int i = 0; i < numClasses; i++)
93 | {
94 | correctedClassified += confusionMatrix[i][i];
95 |
96 | for(int j = 0; j < numClasses; j++)
97 | {
98 | numberInstances[i] += confusionMatrix[j][i];
99 | predictedInstances[i] += confusionMatrix[i][j];
100 | }
101 |
102 | numberInstancesTotal += numberInstances[i];
103 | }
104 |
105 | double mul = 0;
106 |
107 | for(int i = 0; i < numClasses; i++)
108 | mul += numberInstances[i] * predictedInstances[i];
109 |
110 | if(numberInstancesTotal*numberInstancesTotal - mul != 0)
111 | return ((numberInstancesTotal * correctedClassified) - mul) / (double) ((numberInstancesTotal*numberInstancesTotal) - mul);
112 | else
113 | return 1.0;
114 | }
115 |
116 | protected double AUC(int[][] confusionMatrix)
117 | {
118 | if(numClasses == 2)
119 | return AUC(confusionMatrix,0,1); // Assumes 0 positive (minority), 1 negative (majority)
120 | else
121 | {
122 | /** Multi-class AUC **/
123 | double auc = 0.0;
124 |
125 | for(int i = 0; i < numClasses; i++)
126 | for(int j = 0; j < numClasses; j++)
127 | if(i != j)
128 | auc += AUC(confusionMatrix,i,j);
129 |
130 | return auc / (double) (numClasses * (numClasses-1));
131 | }
132 | }
133 |
134 | protected double AUC(int[][] confusionMatrix, int Class1, int Class2)
135 | {
136 | int tp = confusionMatrix[Class1][Class1];
137 | int fp = confusionMatrix[Class1][Class2];
138 | int tn = confusionMatrix[Class2][Class2];
139 | int fn = confusionMatrix[Class2][Class1];
140 |
141 | double tpRate = 1.0, fpRate = 0.0;
142 |
143 | if(tp + fn != 0)
144 | tpRate = tp / (double) (tp + fn);
145 |
146 | if(fp + tn != 0)
147 | fpRate = fp / (double) (fp + tn);
148 |
149 | double auc = (1.0 + tpRate - fpRate) / 2.0;
150 |
151 | return auc;
152 | }
153 |
154 | protected double GMean(int[][] confusionMatrix)
155 | {
156 | double gmean = 1.0;
157 |
158 | int[] numberInstances = new int[numClasses];
159 |
160 | for(int i = 0; i < numClasses; i++)
161 | {
162 | for(int j = 0; j < numClasses; j++)
163 | numberInstances[i] += confusionMatrix[j][i];
164 |
165 | if(numberInstances[i] != 0)
166 | gmean *= (confusionMatrix[i][i] / (double) numberInstances[i]);
167 | }
168 |
169 | return Math.pow(gmean, 1.0 / (double) numClasses);
170 | }
171 |
172 | protected double Accuracy(int[][] confusionMatrix)
173 | {
174 | int correctedClassified = 0;
175 | int numberInstancesTotal = 0;
176 |
177 | for(int i = 0; i < numClasses; i++)
178 | {
179 | correctedClassified += confusionMatrix[i][i];
180 |
181 | for(int j = 0; j < numClasses; j++)
182 | numberInstancesTotal += confusionMatrix[i][j];
183 | }
184 |
185 | return correctedClassified / (double) numberInstancesTotal;
186 | }
187 |
188 | protected double AvgAccuracy(int[][] confusionMatrix)
189 | {
190 | int[] numberInstances = new int[numClasses];
191 |
192 | for(int i = 0; i < numClasses; i++)
193 | for(int j = 0; j < numClasses; j++)
194 | numberInstances[i] += confusionMatrix[j][i];
195 |
196 | double avgAccuracy = 0;
197 | int existingClasses = 0;
198 |
199 | for(int i = 0; i < numClasses; i++)
200 | if(numberInstances[i] != 0)
201 | {
202 | avgAccuracy += confusionMatrix[i][i] / (double) numberInstances[i];
203 | existingClasses++;
204 | }
205 |
206 | return avgAccuracy / (double) existingClasses;
207 | }
208 |
209 | protected double Precision(int[][] confusionMatrix)
210 | {
211 | int[] numberPredictions = new int[numClasses];
212 |
213 | for(int i = 0; i < numClasses; i++)
214 | for(int j = 0; j < numClasses; j++)
215 | numberPredictions[i] += confusionMatrix[i][j];
216 |
217 | double precision = 0;
218 | int existingClasses = 0;
219 |
220 | for(int i = 0; i < numClasses; i++)
221 | if(numberPredictions[i] != 0)
222 | {
223 | precision += confusionMatrix[i][i] / (double) numberPredictions[i];
224 | existingClasses++;
225 | }
226 |
227 | return precision / (double) existingClasses;
228 | }
229 |
230 | protected double Recall(int[][] confusionMatrix)
231 | {
232 | int[] numberInstances = new int[numClasses];
233 |
234 | for(int i = 0; i < numClasses; i++)
235 | for(int j = 0; j < numClasses; j++)
236 | numberInstances[i] += confusionMatrix[j][i];
237 |
238 | double recall = 0;
239 | int existingClasses = 0;
240 |
241 | for(int i = 0; i < numClasses; i++)
242 | if(numberInstances[i] != 0)
243 | {
244 | recall += confusionMatrix[i][i] / (double) numberInstances[i];
245 | existingClasses++;
246 | }
247 |
248 | return recall / (double) existingClasses;
249 | }
250 |
251 | protected double Ratio(int[][] confusionMatrix, int targetClass)
252 | {
253 | int numberInstances = 0;
254 | int numberInstancesTotal = 0;
255 |
256 | for(int i = 0; i < numClasses; i++)
257 | {
258 | numberInstances += confusionMatrix[i][targetClass];
259 |
260 | for(int j = 0; j < numClasses; j++)
261 | numberInstancesTotal += confusionMatrix[i][j];
262 | }
263 |
264 | return numberInstances / (double) numberInstancesTotal;
265 | }
266 |
267 |
268 | public double getAccuracy() {
269 | return Accuracy(this.confusionMatrix);
270 | }
271 |
272 | public double getKappa() {
273 | return Kappa(this.confusionMatrix);
274 | }
275 |
276 | @Override
277 | public Measurement[] getPerformanceMeasurements() {
278 |
279 | Measurement[] measurement = new Measurement[8 + numClasses];
280 |
281 | measurement[0] = new Measurement("classified instances", this.totalObservedInstances);
282 | measurement[1] = new Measurement("Accuracy", Accuracy(this.confusionMatrix));
283 | measurement[2] = new Measurement("AvgAccuracy", AvgAccuracy(this.confusionMatrix));
284 | measurement[3] = new Measurement("AUC", AUC(this.confusionMatrix));
285 | measurement[4] = new Measurement("Kappa", Kappa(this.confusionMatrix));
286 | measurement[5] = new Measurement("G-Mean", GMean(this.confusionMatrix));
287 | measurement[6] = new Measurement("Precision", Precision(this.confusionMatrix));
288 | measurement[7] = new Measurement("Recall", Recall(this.confusionMatrix));
289 |
290 | for(int i = 0; i < numClasses; i++)
291 | measurement[8 + i] = new Measurement("Ratio-Class-" + i, Ratio(this.confusionMatrix, i));
292 |
293 | return measurement;
294 | }
295 |
296 | @Override
297 | public void getDescription(StringBuilder sb, int indent) {
298 | Measurement.getMeasurementsDescription(getPerformanceMeasurements(), sb, indent);
299 | }
300 |
301 | @Override
302 | public void prepareForUseImpl(TaskMonitor monitor,
303 | ObjectRepository repository) {
304 | }
305 |
306 | @Override
307 | public void addResult(Example arg0, Prediction arg1) {
308 | throw new RuntimeException("Designed for scoring classifiers");
309 | }
310 | }
--------------------------------------------------------------------------------
/src/main/java/moa/streams/filters/AddNoiseFilterFeatures.java:
--------------------------------------------------------------------------------
1 | /*
2 | * AddNoiseFilter.java
3 | * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
4 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 | package moa.streams.filters;
21 |
22 | import java.util.ArrayList;
23 | import java.util.Random;
24 |
25 | import moa.core.AutoExpandVector;
26 | import moa.core.DoubleVector;
27 | import moa.core.GaussianEstimator;
28 | import com.yahoo.labs.samoa.instances.InstancesHeader;
29 | import com.github.javacliparser.FloatOption;
30 | import com.github.javacliparser.IntOption;
31 | import com.yahoo.labs.samoa.instances.Instance;
32 |
33 | /**
34 | * Filter for adding random noise to examples in a stream.
35 | *
36 | * @author Alberto Cano (acano@vcu.edu)
37 | * @version $Revision: 1 $
38 | */
39 | public class AddNoiseFilterFeatures extends AbstractStreamFilter {
40 |
41 | @Override
42 | public String getPurposeString() {
43 | return "Adds random noise to features in a stream.";
44 | }
45 |
46 | private static final long serialVersionUID = 1L;
47 |
48 | public IntOption randomSeedOption = new IntOption("randomSeed", 'r', "Seed for random noise.", 1);
49 |
50 | public FloatOption attNoiseFractionOption = new FloatOption("attNoise", 'f', "Probability of an attribute to be disturbed", 0.2, 0.0, 1.0);
51 |
52 | public FloatOption attValueNoiseFractionOption = new FloatOption("attValueNoise", 'a', "The fraction of attribute values to disturb.", 0.1, 0.0, 1.0);
53 |
54 | protected Random random;
55 |
56 | protected AutoExpandVector attValObservers;
57 |
58 | protected boolean[] disturbeAttribute;
59 |
60 | @Override
61 | protected void restartImpl() {
62 | this.disturbeAttribute = null;
63 | this.random = new Random(this.randomSeedOption.getValue());
64 | this.attValObservers = new AutoExpandVector();
65 | }
66 |
67 | @Override
68 | public InstancesHeader getHeader() {
69 | return this.inputStream.getHeader();
70 | }
71 |
72 | public Instance filterInstance(Instance inst) {
73 |
74 | if(this.disturbeAttribute == null) {
75 | this.disturbeAttribute = new boolean[inst.numAttributes()-1];
76 |
77 | int numberAttributesDisturbed = (int) Math.ceil((inst.numAttributes()-1) * this.attNoiseFractionOption.getValue());
78 |
79 | ArrayList attributesPool = new ArrayList();
80 |
81 | for(int i = 0; i < inst.numAttributes()-1; i++) {
82 | attributesPool.add(i);
83 | }
84 |
85 | for(int i = 0; i < numberAttributesDisturbed; i++) {
86 | this.disturbeAttribute[attributesPool.remove(this.random.nextInt(attributesPool.size()))] = true;
87 | }
88 | }
89 |
90 | for (int i = 0; i < inst.numAttributes() -1 ; i++) {
91 | if(this.disturbeAttribute[i]) {
92 | double noiseFrac = this.attValueNoiseFractionOption.getValue();
93 |
94 | if (inst.attribute(i).isNominal()) {
95 | DoubleVector obs = (DoubleVector) this.attValObservers.get(i);
96 | if (obs == null) {
97 | obs = new DoubleVector();
98 | this.attValObservers.set(i, obs);
99 | }
100 | int originalVal = (int) inst.value(i);
101 | if (!inst.isMissing(i)) {
102 | obs.addToValue(originalVal, inst.weight());
103 | }
104 | if ((this.random.nextDouble() < noiseFrac) && (obs.numNonZeroEntries() > 1)) {
105 | do {
106 | inst.setValue(i, this.random.nextInt(obs.numValues()));
107 | } while (((int) inst.value(i) == originalVal) || (obs.getValue((int) inst.value(i)) == 0.0));
108 | }
109 | } else {
110 | GaussianEstimator obs = (GaussianEstimator) this.attValObservers.get(i);
111 | if (obs == null) {
112 | obs = new GaussianEstimator();
113 | this.attValObservers.set(i, obs);
114 | }
115 | obs.addObservation(inst.value(i), inst.weight());
116 | inst.setValue(i, inst.value(i) + this.random.nextGaussian() * obs.getStdDev() * noiseFrac);
117 | }
118 | }
119 | }
120 | return inst;
121 | }
122 |
123 | @Override
124 | public void getDescription(StringBuilder sb, int indent) {
125 | }
126 | }
127 |
--------------------------------------------------------------------------------
/src/main/java/moa/streams/generators/imbalanced/AgrawalGenerator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * AgrawalGenerator.java
3 | * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
4 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 | package moa.streams.generators.imbalanced;
21 |
22 | import com.yahoo.labs.samoa.instances.Attribute;
23 | import com.yahoo.labs.samoa.instances.DenseInstance;
24 | import moa.core.FastVector;
25 | import com.yahoo.labs.samoa.instances.Instance;
26 | import com.yahoo.labs.samoa.instances.Instances;
27 |
28 | import java.util.Random;
29 | import moa.core.Example;
30 | import moa.core.InstanceExample;
31 |
32 | import com.yahoo.labs.samoa.instances.InstancesHeader;
33 | import moa.core.ObjectRepository;
34 | import moa.options.AbstractOptionHandler;
35 | import com.github.javacliparser.FlagOption;
36 | import com.github.javacliparser.FloatOption;
37 | import com.github.javacliparser.IntOption;
38 | import moa.streams.ExampleStream;
39 | import moa.streams.InstanceStream;
40 | import moa.tasks.TaskMonitor;
41 |
42 | /**
43 | * Stream generator for Agrawal dataset.
44 | * Generator described in paper:
 * Rakesh Agrawal, Tomasz Imielinski, and Arun Swami,
46 | * "Database Mining: A Performance Perspective",
47 | * IEEE Transactions on Knowledge and Data Engineering,
48 | * 5(6), December 1993.
49 | *
50 | * Public C source code available at:
51 | *
52 | * http://www.almaden.ibm.com/cs/projects/iis/hdb/Projects/data_mining/datasets/syndata.html
53 | *
54 | * Notes:
55 | * The built in functions are based on the paper (page 924),
56 | * which turn out to be functions pred20 thru pred29 in the public C implementation.
57 | * Perturbation function works like C implementation rather than description in paper.
58 | *
59 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
60 | * @version $Revision: 7 $
61 | */
62 | public class AgrawalGenerator extends AbstractOptionHandler implements
63 | InstanceStream {
64 |
65 | @Override
66 | public String getPurposeString() {
67 | return "Generates one of ten different pre-defined loan functions.";
68 | }
69 |
70 | private static final long serialVersionUID = 1L;
71 |
72 | public IntOption functionOption = new IntOption("function", 'f',
73 | "Classification function used, as defined in the original paper.",
74 | 1, 1, 10);
75 |
76 | public IntOption instanceRandomSeedOption = new IntOption(
77 | "instanceRandomSeed", 'i',
78 | "Seed for random generation of instances.", 1);
79 |
80 | public FloatOption peturbFractionOption = new FloatOption("peturbFraction",
81 | 'p',
82 | "The amount of peturbation (noise) introduced to numeric values.",
83 | 0.05, 0.0, 1.0);
84 |
85 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm',
86 | "Percentage of minority class examples", 0.1, 0, 1);
87 |
88 | protected interface ClassFunction {
89 |
90 | public int determineClass(double salary, double commission, int age,
91 | int elevel, int car, int zipcode, double hvalue, int hyears,
92 | double loan);
93 | }
94 |
95 | protected static ClassFunction[] classificationFunctions = {
96 | // function 1
97 | new ClassFunction() {
98 |
99 | @Override
100 | public int determineClass(double salary, double commission,
101 | int age, int elevel, int car, int zipcode,
102 | double hvalue, int hyears, double loan) {
103 | return ((age < 40) || (60 <= age)) ? 0 : 1;
104 | }
105 | },
106 | // function 2
107 | new ClassFunction() {
108 |
109 | @Override
110 | public int determineClass(double salary, double commission,
111 | int age, int elevel, int car, int zipcode,
112 | double hvalue, int hyears, double loan) {
113 | if (age < 40) {
114 | return ((50000 <= salary) && (salary <= 100000)) ? 0
115 | : 1;
116 | } else if (age < 60) {// && age >= 40
117 | return ((75000 <= salary) && (salary <= 125000)) ? 0
118 | : 1;
119 | } else {// age >= 60
120 | return ((25000 <= salary) && (salary <= 75000)) ? 0 : 1;
121 | }
122 | }
123 | },
124 | // function 3
125 | new ClassFunction() {
126 |
127 | @Override
128 | public int determineClass(double salary, double commission,
129 | int age, int elevel, int car, int zipcode,
130 | double hvalue, int hyears, double loan) {
131 | if (age < 40) {
132 | return ((elevel == 0) || (elevel == 1)) ? 0 : 1;
133 | } else if (age < 60) { // && age >= 40
134 | return ((elevel == 1) || (elevel == 2) || (elevel == 3)) ? 0
135 | : 1;
136 | } else { // age >= 60
137 | return ((elevel == 2) || (elevel == 3) || (elevel == 4)) ? 0
138 | : 1;
139 | }
140 | }
141 | },
142 | // function 4
143 | new ClassFunction() {
144 |
145 | @Override
146 | public int determineClass(double salary, double commission,
147 | int age, int elevel, int car, int zipcode,
148 | double hvalue, int hyears, double loan) {
149 | if (age < 40) {
150 | if ((elevel == 0) || (elevel == 1)) {
151 | return ((25000 <= salary) && (salary <= 75000)) ? 0
152 | : 1;
153 | }
154 | return ((50000 <= salary) && (salary <= 100000)) ? 0
155 | : 1;
156 | } else if (age < 60) {// && age >= 40
157 | if ((elevel == 1) || (elevel == 2) || (elevel == 3)) {
158 | return ((50000 <= salary) && (salary <= 100000)) ? 0
159 | : 1;
160 | }
161 | return ((75000 <= salary) && (salary <= 125000)) ? 0
162 | : 1;
163 | } else {// age >= 60
164 | if ((elevel == 2) || (elevel == 3) || (elevel == 4)) {
165 | return ((50000 <= salary) && (salary <= 100000)) ? 0
166 | : 1;
167 | }
168 | return ((25000 <= salary) && (salary <= 75000)) ? 0 : 1;
169 | }
170 | }
171 | },
172 | // function 5
173 | new ClassFunction() {
174 |
175 | @Override
176 | public int determineClass(double salary, double commission,
177 | int age, int elevel, int car, int zipcode,
178 | double hvalue, int hyears, double loan) {
179 | if (age < 40) {
180 | if ((50000 <= salary) && (salary <= 100000)) {
181 | return ((100000 <= loan) && (loan <= 300000)) ? 0
182 | : 1;
183 | }
184 | return ((200000 <= loan) && (loan <= 400000)) ? 0 : 1;
185 | } else if (age < 60) {// && age >= 40
186 | if ((75000 <= salary) && (salary <= 125000)) {
187 | return ((200000 <= loan) && (loan <= 400000)) ? 0
188 | : 1;
189 | }
190 | return ((300000 <= loan) && (loan <= 500000)) ? 0 : 1;
191 | } else {// age >= 60
192 | if ((25000 <= salary) && (salary <= 75000)) {
193 | return ((300000 <= loan) && (loan <= 500000)) ? 0
194 | : 1;
195 | }
196 | return ((100000 <= loan) && (loan <= 300000)) ? 0 : 1;
197 | }
198 | }
199 | },
200 | // function 6
201 | new ClassFunction() {
202 |
203 | @Override
204 | public int determineClass(double salary, double commission,
205 | int age, int elevel, int car, int zipcode,
206 | double hvalue, int hyears, double loan) {
207 | double totalSalary = salary + commission;
208 | if (age < 40) {
209 | return ((50000 <= totalSalary) && (totalSalary <= 100000)) ? 0
210 | : 1;
211 | } else if (age < 60) {// && age >= 40
212 | return ((75000 <= totalSalary) && (totalSalary <= 125000)) ? 0
213 | : 1;
214 | } else {// age >= 60
215 | return ((25000 <= totalSalary) && (totalSalary <= 75000)) ? 0
216 | : 1;
217 | }
218 | }
219 | },
220 | // function 7
221 | new ClassFunction() {
222 |
223 | @Override
224 | public int determineClass(double salary, double commission,
225 | int age, int elevel, int car, int zipcode,
226 | double hvalue, int hyears, double loan) {
227 | double disposable = (2.0 * (salary + commission) / 3.0
228 | - loan / 5.0 - 20000.0);
229 | return disposable > 0 ? 0 : 1;
230 | }
231 | },
232 | // function 8
233 | new ClassFunction() {
234 |
235 | @Override
236 | public int determineClass(double salary, double commission,
237 | int age, int elevel, int car, int zipcode,
238 | double hvalue, int hyears, double loan) {
239 | double disposable = (2.0 * (salary + commission) / 3.0
240 | - 5000.0 * elevel - 20000.0);
241 | return disposable > 0 ? 0 : 1;
242 | }
243 | },
244 | // function 9
245 | new ClassFunction() {
246 |
247 | @Override
248 | public int determineClass(double salary, double commission,
249 | int age, int elevel, int car, int zipcode,
250 | double hvalue, int hyears, double loan) {
251 | double disposable = (2.0 * (salary + commission) / 3.0
252 | - 5000.0 * elevel - loan / 5.0 - 10000.0);
253 | return disposable > 0 ? 0 : 1;
254 | }
255 | },
256 | // function 10
257 | new ClassFunction() {
258 |
259 | @Override
260 | public int determineClass(double salary, double commission,
261 | int age, int elevel, int car, int zipcode,
262 | double hvalue, int hyears, double loan) {
263 | double equity = 0.0;
264 | if (hyears >= 20) {
265 | equity = hvalue * (hyears - 20.0) / 10.0;
266 | }
267 | double disposable = (2.0 * (salary + commission) / 3.0
268 | - 5000.0 * elevel + equity / 5.0 - 10000.0);
269 | return disposable > 0 ? 0 : 1;
270 | }
271 | }};
272 |
273 | protected InstancesHeader streamHeader;
274 |
275 | protected Random instanceRandom;
276 |
277 | protected boolean nextClassShouldBeZero;
278 |
279 | @Override
280 | protected void prepareForUseImpl(TaskMonitor monitor,
281 | ObjectRepository repository) {
282 | // generate header
283 | FastVector attributes = new FastVector();
284 | attributes.addElement(new Attribute("salary"));
285 | attributes.addElement(new Attribute("commission"));
286 | attributes.addElement(new Attribute("age"));
287 | FastVector elevelLabels = new FastVector();
288 | for (int i = 0; i < 5; i++) {
289 | elevelLabels.addElement("level" + i);
290 | }
291 | attributes.addElement(new Attribute("elevel", elevelLabels));
292 | FastVector carLabels = new FastVector();
293 | for (int i = 0; i < 20; i++) {
294 | carLabels.addElement("car" + (i + 1));
295 | }
296 | attributes.addElement(new Attribute("car", carLabels));
297 | FastVector zipCodeLabels = new FastVector();
298 | for (int i = 0; i < 9; i++) {
299 | zipCodeLabels.addElement("zipcode" + (i + 1));
300 | }
301 | attributes.addElement(new Attribute("zipcode", zipCodeLabels));
302 | attributes.addElement(new Attribute("hvalue"));
303 | attributes.addElement(new Attribute("hyears"));
304 | attributes.addElement(new Attribute("loan"));
305 | FastVector classLabels = new FastVector();
306 | classLabels.addElement("groupA");
307 | classLabels.addElement("groupB");
308 | attributes.addElement(new Attribute("class", classLabels));
309 | this.streamHeader = new InstancesHeader(new Instances(
310 | getCLICreationString(InstanceStream.class), attributes, 0));
311 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
312 | restart();
313 | }
314 |
315 | @Override
316 | public long estimatedRemainingInstances() {
317 | return -1;
318 | }
319 |
320 | @Override
321 | public InstancesHeader getHeader() {
322 | return this.streamHeader;
323 | }
324 |
325 | @Override
326 | public boolean hasMoreInstances() {
327 | return true;
328 | }
329 |
330 | @Override
331 | public boolean isRestartable() {
332 | return true;
333 | }
334 |
335 | @Override
336 | public InstanceExample nextInstance() {
337 | double salary = 0, commission = 0, hvalue = 0, loan = 0;
338 | int age = 0, elevel = 0, car = 0, zipcode = 0, hyears = 0, group = 0;
339 |
340 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0;
341 |
342 | do {
343 | // generate attributes
344 | salary = 20000.0 + 130000.0 * this.instanceRandom.nextDouble();
345 | commission = (salary >= 75000.0) ? 0
346 | : (10000.0 + 65000.0 * this.instanceRandom.nextDouble());
347 | // true to c implementation:
348 | // if (instanceRandom.nextDouble() < 0.5 && salary < 75000.0)
349 | // commission = 10000.0 + 65000.0 * instanceRandom.nextDouble();
350 | age = 20 + this.instanceRandom.nextInt(61);
351 | elevel = this.instanceRandom.nextInt(5);
352 | car = this.instanceRandom.nextInt(20);
353 | zipcode = this.instanceRandom.nextInt(9);
354 | hvalue = (9.0 - zipcode) * 100000.0
355 | * (0.5 + this.instanceRandom.nextDouble());
356 | hyears = 1 + this.instanceRandom.nextInt(30);
357 | loan = this.instanceRandom.nextDouble() * 500000.0;
358 | // determine class
359 | group = classificationFunctions[this.functionOption.getValue() - 1].determineClass(salary, commission, age, elevel, car, zipcode, hvalue, hyears, loan);
360 | }while(group != label);
361 |
362 | // perturb values
363 | if (this.peturbFractionOption.getValue() > 0.0) {
364 | salary = perturbValue(salary, 20000, 150000);
365 | if (commission > 0) {
366 | commission = perturbValue(commission, 10000, 75000);
367 | }
368 | age = (int) Math.round(perturbValue(age, 20, 80));
369 | hvalue = perturbValue(hvalue, (9.0 - zipcode) * 100000.0, 0, 135000);
370 | hyears = (int) Math.round(perturbValue(hyears, 1, 30));
371 | loan = perturbValue(loan, 0, 500000);
372 | }
373 | // construct instance
374 | InstancesHeader header = getHeader();
375 | Instance inst = new DenseInstance(header.numAttributes());
376 | inst.setValue(0, salary);
377 | inst.setValue(1, commission);
378 | inst.setValue(2, age);
379 | inst.setValue(3, elevel);
380 | inst.setValue(4, car);
381 | inst.setValue(5, zipcode);
382 | inst.setValue(6, hvalue);
383 | inst.setValue(7, hyears);
384 | inst.setValue(8, loan);
385 | inst.setDataset(header);
386 | inst.setClassValue(group);
387 | return new InstanceExample(inst);
388 | }
389 |
390 | protected double perturbValue(double val, double min, double max) {
391 | return perturbValue(val, max - min, min, max);
392 | }
393 |
394 | protected double perturbValue(double val, double range, double min,
395 | double max) {
396 | val += range * (2.0 * (this.instanceRandom.nextDouble() - 0.5))
397 | * this.peturbFractionOption.getValue();
398 | if (val < min) {
399 | val = min;
400 | } else if (val > max) {
401 | val = max;
402 | }
403 | return val;
404 | }
405 |
406 | @Override
407 | public void restart() {
408 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
409 | this.nextClassShouldBeZero = false;
410 | }
411 |
412 | @Override
413 | public void getDescription(StringBuilder sb, int indent) {
414 | // TODO Auto-generated method stub
415 | }
416 | }
417 |
--------------------------------------------------------------------------------
/src/main/java/moa/streams/generators/imbalanced/AssetNegotiationGenerator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * AssetNegotiationGenerator.java
3 | * Copyright (C) 2016 Pontifícia Universidade Católica do Paraná, Curitiba, Brazil
4 | * @author Jean Paul Barddal (jean.barddal@ppgia.pucpr.br)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 |
21 | package moa.streams.generators.imbalanced;
22 |
23 | import com.github.javacliparser.FloatOption;
24 | import com.github.javacliparser.IntOption;
25 | import com.yahoo.labs.samoa.instances.Attribute;
26 | import com.yahoo.labs.samoa.instances.DenseInstance;
27 | import com.yahoo.labs.samoa.instances.Instance;
28 | import com.yahoo.labs.samoa.instances.Instances;
29 | import com.yahoo.labs.samoa.instances.InstancesHeader;
30 | import java.util.Arrays;
31 | import java.util.Random;
32 | import moa.core.FastVector;
33 | import moa.core.InstanceExample;
34 | import moa.core.ObjectRepository;
35 | import moa.options.AbstractOptionHandler;
36 | import moa.streams.InstanceStream;
37 | import moa.tasks.TaskMonitor;
38 |
39 | /**
40 | *
41 | * @author Jean Paul Barddal
42 | * @author Fabrício Enembreck
43 | *
44 | * @version 1.0
45 | * @see Originally discussed in F. Enembreck, B. C. Ávila, E. E. Scalabrin &
46 | * J-P. Barthès. LEARNING DRIFTING NEGOTIATIONS. In Applied Artificial
47 | * Intelligence: An International Journal. Volume 21, Issue 9, 2007. DOI:
48 | * 10.1080/08839510701526954
49 | * @see First used in the data stream configuration in J. P. Barddal, H. M.
50 | * Gomes, F. Enembreck, B. Pfahringer & A. Bifet. ON DYNAMIC FEATURE WEIGHTING
51 | * FOR FEATURE DRIFTING DATA STREAMS. In European Conference on Machine Learning
52 | * and Principles and Practice of Knowledge Discovery (ECML/PKDD'16). 2016.
53 | */
54 |
55 | public class AssetNegotiationGenerator
56 | extends AbstractOptionHandler
57 | implements InstanceStream {
58 |
59 | /*
60 | * OPTIONS
61 | */
62 | public IntOption functionOption = new IntOption("function", 'f',
63 | "Classification function used, as defined in the original paper.",
64 | 1, 1, 5);
65 | public FloatOption noisePercentage = new FloatOption("noise", 'n',
66 | "% of class noise.", 0.05, 0.0, 1.0f);
67 |
68 | public IntOption instanceRandomSeedOption = new IntOption(
69 | "instanceRandomSeed", 'i',
70 | "Seed for random generation of instances.", 1);
71 |
72 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm',
73 | "Percentage of minority class examples", 0.1, 0, 1);
74 |
75 | /*
76 | * INTERNALS
77 | */
78 | protected InstancesHeader streamHeader;
79 |
80 | protected Random instanceRandom;
81 |
82 | protected boolean nextClassShouldBeZero;
83 |
84 | protected ClassFunction classFunction;
85 |
86 |
87 | /*
88 | * FEATURE DEFINITIONS
89 | */
90 | protected static String colorValues[] = {"black",
91 | "blue",
92 | "cyan",
93 | "brown",
94 | "red",
95 | "green",
96 | "yellow",
97 | "magenta"};
98 |
99 | protected static String priceValues[] = {"veryLow",
100 | "low",
101 | "normal",
102 | "high",
103 | "veryHigh",
104 | "quiteHigh",
105 | "enormous",
106 | "non_salable"};
107 |
108 | protected static String paymentValues[] = {"0",
109 | "30",
110 | "60",
111 | "90",
112 | "120",
113 | "150",
114 | "180",
115 | "210",
116 | "240"};
117 |
118 | protected static String amountValues[] = {"veryLow",
119 | "low",
120 | "normal",
121 | "high",
122 | "veryHigh",
123 | "quiteHigh",
124 | "enormous",
125 | "non_ensured"};
126 |
127 | protected static String deliveryDelayValues[] = {"veryLow",
128 | "low",
129 | "normal",
130 | "high",
131 | "veryHigh"};
132 |
133 | protected static String classValues[] = {"interested", "notInterested"};
134 |
135 | /*
136 | * Labeling functions
137 | */
138 | protected interface ClassFunction {
139 |
140 | public int determineClass(String color,
141 | String price,
142 | String payment,
143 | String amount,
144 | String deliveryDelay);
145 |
146 | public Instance makeTrue(Instance intnc);
147 | }
148 |
149 | protected static ClassFunction concepts[] = {
150 | new ClassFunction() {
151 | Random r = new Random(Integer.MAX_VALUE);
152 |
153 | @Override
154 | public int determineClass(String color,
155 | String price,
156 | String payment,
157 | String amount,
158 | String deliveryDelay) {
159 | if ((price.equals("normal") && amount.equals("high")
160 | || (color.equals("brown") && price.equals("veryLow")
161 | && deliveryDelay.equals("high")))) {
162 | return indexOfValue("interested", classValues);
163 | }
164 | return indexOfValue("notInterested", classValues);
165 | }
166 |
167 | @Override
168 | public Instance makeTrue(Instance intnc) {
169 | int part = r.nextInt(2);
170 | if (part == 0) {
171 |
172 | intnc.setValue(1, indexOfValue("normal", priceValues));
173 | intnc.setValue(3, indexOfValue("high", amountValues));
174 | } else {
175 | intnc.setValue(0, indexOfValue("brown", colorValues));
176 | intnc.setValue(1, indexOfValue("veryLow", priceValues));
177 | intnc.setValue(4, indexOfValue("high", deliveryDelayValues));
178 | }
179 | intnc.setClassValue(indexOfValue("interested", classValues));
180 | return intnc;
181 | }
182 |
183 | },
184 | new ClassFunction() {
185 | Random r = new Random(Integer.MAX_VALUE);
186 |
187 | @Override
188 | public int determineClass(String color,
189 | String price,
190 | String payment,
191 | String amount,
192 | String deliveryDelay) {
193 | if (price.equals("high") && amount.equals("veryHigh")
194 | && deliveryDelay.equals("high")) {
195 | return indexOfValue("interested", classValues);
196 | }
197 | return indexOfValue("notInterested", classValues);
198 | }
199 |
200 | @Override
201 | public Instance makeTrue(Instance intnc) {
202 | intnc.setValue(1, indexOfValue("high", priceValues));
203 | intnc.setValue(3, indexOfValue("veryHigh", amountValues));
204 | intnc.setValue(4, indexOfValue("high", deliveryDelayValues));
205 | intnc.setClassValue(Arrays.asList(classValues).indexOf("interested"));
206 | return intnc;
207 | }
208 | },
209 | new ClassFunction() {
210 | Random r = new Random(Integer.MAX_VALUE);
211 |
212 | @Override
213 | public int determineClass(String color,
214 | String price,
215 | String payment,
216 | String amount,
217 | String deliveryDelay) {
218 | if ((price.equals("veryLow")
219 | && payment.equals("0") && amount.equals("high"))
220 | || (color.equals("red") && price.equals("low")
221 | && payment.equals("30"))) {
222 | return indexOfValue("interested", classValues);
223 | }
224 | return indexOfValue("notInterested", classValues);
225 | }
226 |
227 | @Override
228 | public Instance makeTrue(Instance intnc) {
229 | int part = r.nextInt(2);
230 | if (part == 0) {
231 | intnc.setValue(1, indexOfValue("veryLow", priceValues));
232 | intnc.setValue(2, indexOfValue("0", paymentValues));
233 | intnc.setValue(3, indexOfValue("high", amountValues));
234 | } else {
235 | intnc.setValue(0, indexOfValue("red", colorValues));
236 | intnc.setValue(1, indexOfValue("low", priceValues));
237 | intnc.setValue(2, indexOfValue("30", paymentValues));
238 | }
239 | intnc.setClassValue(Arrays.asList(classValues).indexOf("interested"));
240 | return intnc;
241 | }
242 | },
243 | new ClassFunction() {
244 | Random r = new Random(Integer.MAX_VALUE);
245 |
246 | @Override
247 | public int determineClass(String color,
248 | String price,
249 | String payment,
250 | String amount,
251 | String deliveryDelay) {
252 | if ((color.equals("black")
253 | && payment.equals("90")
254 | && deliveryDelay.equals("veryLow"))
255 | || (color.equals("magenta")
256 | && price.equals("high")
257 | && deliveryDelay.equals("veryLow"))) {
258 | return indexOfValue("interested", classValues);
259 | }
260 | return indexOfValue("notInterested", classValues);
261 | }
262 |
263 | @Override
264 | public Instance makeTrue(Instance intnc) {
265 | int part = r.nextInt(2);
266 | if (part == 0) {
267 | intnc.setValue(0, indexOfValue("black", colorValues));
268 | intnc.setValue(2, indexOfValue("90", paymentValues));
269 | intnc.setValue(4, indexOfValue("veryLow", deliveryDelayValues));
270 | } else {
271 | intnc.setValue(0, indexOfValue("magenta", colorValues));
272 | intnc.setValue(1, indexOfValue("high", priceValues));
273 | intnc.setValue(4, indexOfValue("veryLow", deliveryDelayValues));
274 | }
275 | intnc.setClassValue(Arrays.asList(classValues).indexOf("interested"));
276 | return intnc;
277 | }
278 | },
279 | new ClassFunction() {
280 | Random r = new Random(Integer.MAX_VALUE);
281 |
282 | @Override
283 | public int determineClass(String color,
284 | String price,
285 | String payment,
286 | String amount,
287 | String deliveryDelay) {
288 | if ((color.equals("blue")
289 | && payment.equals("60")
290 | && amount.equals("low")
291 | && deliveryDelay.equals("normal"))
292 | || (color.equals("cyan")
293 | && amount.equals("low")
294 | && deliveryDelay.equals("normal"))) {
295 | return indexOfValue("interested", classValues);
296 | }
297 | return indexOfValue("notInterested", classValues);
298 | }
299 |
300 | @Override
301 | public Instance makeTrue(Instance intnc) {
302 | int part = r.nextInt(2);
303 | if (part == 0) {
304 | intnc.setValue(0, indexOfValue("blue", colorValues));
305 | intnc.setValue(2, indexOfValue("60", paymentValues));
306 | intnc.setValue(3, indexOfValue("low", amountValues));
307 | intnc.setValue(4, indexOfValue("normal", deliveryDelayValues));
308 | } else {
309 | intnc.setValue(0, indexOfValue("cyan", colorValues));
310 | intnc.setValue(3, indexOfValue("low", amountValues));
311 | intnc.setValue(4, indexOfValue("normal", deliveryDelayValues));
312 | }
313 | intnc.setClassValue(Arrays.asList(classValues).indexOf("interested"));
314 | return intnc;
315 | }
316 | }
317 | };
318 |
319 | /*
320 | * Generator core
321 | */
322 |
323 | @Override
324 | public void getDescription(StringBuilder sb, int indent) {
325 | sb.append("Generates instances based on 5 different concept functions "
326 | + "that describe whether another agent is "
327 | + "interested or not in an item.");
328 | }
329 |
330 | @Override
331 | protected void prepareForUseImpl(TaskMonitor tm, ObjectRepository or) {
332 |
333 | classFunction = concepts[this.functionOption.getValue() - 1];
334 |
335 | FastVector attributes = new FastVector();
336 | attributes.addElement(new Attribute("color",
337 | Arrays.asList(colorValues)));
338 | attributes.addElement(new Attribute("price",
339 | Arrays.asList(priceValues)));
340 | attributes.addElement(new Attribute("payment",
341 | Arrays.asList(paymentValues)));
342 | attributes.addElement(new Attribute("amount",
343 | Arrays.asList(amountValues)));
344 | attributes.addElement(new Attribute("deliveryDelay",
345 | Arrays.asList(deliveryDelayValues)));
346 |
347 | this.instanceRandom = new Random(System.currentTimeMillis());
348 |
349 | FastVector classLabels = new FastVector();
350 | for (int i = 0; i < classValues.length; i++) {
351 | classLabels.addElement(classValues[i]);
352 | }
353 |
354 | attributes.addElement(new Attribute("class", classLabels));
355 | this.streamHeader = new InstancesHeader(new Instances(
356 | getCLICreationString(InstanceStream.class), attributes, 0));
357 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
358 |
359 | restart();
360 | }
361 |
362 | @Override
363 | public InstancesHeader getHeader() {
364 | return streamHeader;
365 | }
366 |
367 | @Override
368 | public long estimatedRemainingInstances() {
369 | return Integer.MAX_VALUE;
370 | }
371 |
372 | @Override
373 | public boolean hasMoreInstances() {
374 | return true;
375 | }
376 |
377 | @Override
378 | public InstanceExample nextInstance() {
379 | Instance instnc = null;
380 |
381 | int classValue = -1;
382 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0;
383 |
384 | do{
385 | //randomize indexes for new instance
386 | int indexColor = this.instanceRandom.nextInt(colorValues.length);
387 | int indexPrice = this.instanceRandom.nextInt(priceValues.length);
388 | int indexPayment = this.instanceRandom.nextInt(paymentValues.length);
389 | int indexAmount = this.instanceRandom.nextInt(amountValues.length);
390 | int indexDelivery = this.instanceRandom.nextInt(deliveryDelayValues.length);
391 | //retrieve values
392 | String color = colorValues[indexColor];
393 | String price = priceValues[indexPrice];
394 | String payment = paymentValues[indexPayment];
395 | String amount = amountValues[indexAmount];
396 | String delivery = deliveryDelayValues[indexDelivery];
397 | classValue = classFunction.determineClass(color, price, payment, amount, delivery);
398 |
399 | instnc = new DenseInstance(streamHeader.numAttributes());
400 | //set values
401 | instnc.setDataset(this.getHeader());
402 | instnc.setValue(0, Arrays.asList(colorValues).indexOf(color));
403 | instnc.setValue(1, Arrays.asList(priceValues).indexOf(price));
404 | instnc.setValue(2, Arrays.asList(paymentValues).indexOf(payment));
405 | instnc.setValue(3, Arrays.asList(amountValues).indexOf(amount));
406 | instnc.setValue(4, Arrays.asList(deliveryDelayValues).indexOf(delivery));
407 |
408 | instnc.setClassValue((int) classValue);
409 |
410 | }while(classValue != label);
411 |
412 | //add noise
413 | int newClassValue = addNoise((int) instnc.classValue());
414 | instnc.setClassValue(newClassValue);
415 | return new InstanceExample(instnc);
416 | }
417 |
418 | @Override
419 | public boolean isRestartable() {
420 | return true;
421 | }
422 |
423 | @Override
424 | public void restart() {
425 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
426 | this.nextClassShouldBeZero = false;
427 | }
428 |
429 | int addNoise(int classObtained) {
430 | if (this.instanceRandom.nextFloat() <= this.noisePercentage.getValue()) {
431 | classObtained = classObtained == 0 ? 1 : 0;
432 | }
433 | return classObtained;
434 | }
435 |
436 | private static int indexOfValue(String value, Object[] arr) {
437 | int index = Arrays.asList(arr).indexOf(value);
438 | return index;
439 | }
440 | }
--------------------------------------------------------------------------------
/src/main/java/moa/streams/generators/imbalanced/HyperplaneGenerator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * HyperplaneGenerator.java
3 | * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
4 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 | package moa.streams.generators.imbalanced;
21 |
22 | import java.util.Arrays;
23 | import java.util.Random;
24 | import com.github.javacliparser.FloatOption;
25 | import com.github.javacliparser.IntOption;
26 | import moa.core.FastVector;
27 | import moa.core.InstanceExample;
28 | import moa.core.ObjectRepository;
29 | import moa.options.AbstractOptionHandler;
30 | import moa.streams.InstanceStream;
31 | import moa.tasks.TaskMonitor;
32 | import com.yahoo.labs.samoa.instances.Attribute;
33 | import com.yahoo.labs.samoa.instances.DenseInstance;
34 | import com.yahoo.labs.samoa.instances.Instance;
35 | import com.yahoo.labs.samoa.instances.Instances;
36 | import com.yahoo.labs.samoa.instances.InstancesHeader;
37 |
38 | /**
39 | * Stream generator for Hyperplane data stream.
40 | *
41 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
42 | * @version $Revision: 7 $
43 | */
44 | public class HyperplaneGenerator extends AbstractOptionHandler implements
45 | InstanceStream {
46 |
47 | @Override
48 | public String getPurposeString() {
49 | return "Generates a problem of predicting class of a rotating hyperplane.";
50 | }
51 |
52 | private static final long serialVersionUID = 1L;
53 |
54 | public IntOption instanceRandomSeedOption = new IntOption(
55 | "instanceRandomSeed", 'i',
56 | "Seed for random generation of instances.", 1);
57 |
58 | public IntOption numClassesOption = new IntOption("numClasses", 'c',
59 | "The number of classes to generate.", 2, 2, Integer.MAX_VALUE);
60 |
61 | public IntOption numAttsOption = new IntOption("numAtts", 'a',
62 | "The number of attributes to generate.", 10, 0, Integer.MAX_VALUE);
63 |
64 | public IntOption numDriftAttsOption = new IntOption("numDriftAtts", 'k',
65 | "The number of attributes with drift.", 2, 0, Integer.MAX_VALUE);
66 |
67 | public FloatOption magChangeOption = new FloatOption("magChange", 't',
68 | "Magnitude of the change for every example", 0.0, 0.0, 1.0);
69 |
70 | public IntOption noisePercentageOption = new IntOption("noisePercentage",
71 | 'n', "Percentage of noise to add to the data.", 5, 0, 100);
72 |
73 | public IntOption sigmaPercentageOption = new IntOption("sigmaPercentage",
74 | 's', "Percentage of probability that the direction of change is reversed.", 10, 0, 100);
75 |
76 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm',
77 | "Percentage of minority class examples", 0.1, 0, 1);
78 |
79 | protected InstancesHeader streamHeader;
80 |
81 | protected Random instanceRandom;
82 |
83 | protected double[] weights;
84 |
85 | protected int[] sigma;
86 |
87 | public int numberInstance;
88 |
89 | @Override
90 | protected void prepareForUseImpl(TaskMonitor monitor,
91 | ObjectRepository repository) {
92 | monitor.setCurrentActivity("Preparing hyperplane...", -1.0);
93 | generateHeader();
94 | restart();
95 | }
96 |
97 | protected void generateHeader() {
98 | FastVector attributes = new FastVector();
99 | for (int i = 0; i < this.numAttsOption.getValue(); i++) {
100 | attributes.addElement(new Attribute("att" + (i + 1)));
101 | }
102 |
103 | FastVector classLabels = new FastVector();
104 | for (int i = 0; i < this.numClassesOption.getValue(); i++) {
105 | classLabels.addElement("class" + (i + 1));
106 | }
107 | attributes.addElement(new Attribute("class", classLabels));
108 | this.streamHeader = new InstancesHeader(new Instances(
109 | getCLICreationString(InstanceStream.class), attributes, 0));
110 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
111 | }
112 |
113 | @Override
114 | public long estimatedRemainingInstances() {
115 | return -1;
116 | }
117 |
118 | @Override
119 | public InstancesHeader getHeader() {
120 | return this.streamHeader;
121 | }
122 |
123 | @Override
124 | public boolean hasMoreInstances() {
125 | return true;
126 | }
127 |
128 | @Override
129 | public boolean isRestartable() {
130 | return true;
131 | }
132 |
133 | @Override
134 | public InstanceExample nextInstance() {
135 |
136 | int numAtts = this.numAttsOption.getValue();
137 | double[] attVals = new double[numAtts + 1];
138 | double sum = 0.0;
139 | double sumWeights = 0.0;
140 | int classLabel = -1;
141 |
142 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0;
143 |
144 | do
145 | {
146 | sum = 0.0;
147 | sumWeights = 0.0;
148 |
149 | for (int i = 0; i < numAtts; i++) {
150 | attVals[i] = this.instanceRandom.nextDouble();
151 | sum += this.weights[i] * attVals[i];
152 | sumWeights += this.weights[i];
153 | }
154 |
155 | if (sum >= sumWeights * 0.5) {
156 | classLabel = 1;
157 | } else {
158 | classLabel = 0;
159 | }
160 | }while(classLabel != label);
161 |
162 | //Add Noise
163 | if ((1 + (this.instanceRandom.nextInt(100))) <= this.noisePercentageOption.getValue()) {
164 | classLabel = (classLabel == 0 ? 1 : 0);
165 | }
166 |
167 | Instance inst = new DenseInstance(1.0, attVals);
168 | inst.setDataset(getHeader());
169 | inst.setClassValue(classLabel);
170 | addDrift();
171 | return new InstanceExample(inst);
172 | }
173 |
174 | private void addDrift() {
175 | for (int i = 0; i < this.numDriftAttsOption.getValue(); i++) {
176 | this.weights[i] += (double) ((double) sigma[i]) * ((double) this.magChangeOption.getValue());
177 | if (//this.weights[i] >= 1.0 || this.weights[i] <= 0.0 ||
178 | (1 + (this.instanceRandom.nextInt(100))) <= this.sigmaPercentageOption.getValue()) {
179 | this.sigma[i] *= -1;
180 | }
181 | }
182 | }
183 |
184 | @Override
185 | public void restart() {
186 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
187 | this.weights = new double[this.numAttsOption.getValue()];
188 | this.sigma = new int[this.numAttsOption.getValue()];
189 | for (int i = 0; i < this.numAttsOption.getValue(); i++) {
190 | this.weights[i] = this.instanceRandom.nextDouble();
191 | this.sigma[i] = (i < this.numDriftAttsOption.getValue() ? 1 : 0);
192 | }
193 | }
194 |
/**
 * Currently a no-op: nothing is appended to {@code sb}.
 *
 * @param sb     builder that would receive the description (left untouched)
 * @param indent indentation level (ignored)
 */
@Override
public void getDescription(StringBuilder sb, int indent) {
    // TODO Auto-generated method stub
}
199 | }
200 |
--------------------------------------------------------------------------------
/src/main/java/moa/streams/generators/imbalanced/MixedGenerator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * MixedGenerator.java
3 | * Copyright (C) 2016 Instituto Federal de Pernambuco
4 | * @author Paulo Gonçalves (paulogoncalves@recife.ifpe.edu.br)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 | package moa.streams.generators.imbalanced;
21 |
22 | import com.github.javacliparser.FlagOption;
23 | import com.github.javacliparser.FloatOption;
24 | import com.github.javacliparser.IntOption;
25 | import com.yahoo.labs.samoa.instances.Attribute;
26 | import com.yahoo.labs.samoa.instances.DenseInstance;
27 | import com.yahoo.labs.samoa.instances.Instance;
28 | import com.yahoo.labs.samoa.instances.Instances;
29 | import com.yahoo.labs.samoa.instances.InstancesHeader;
30 | import java.util.ArrayList;
31 | import java.util.List;
32 | import java.util.Random;
33 | import moa.core.InstanceExample;
34 |
35 | import moa.core.ObjectRepository;
36 | import moa.options.AbstractOptionHandler;
37 | import moa.streams.InstanceStream;
38 | import moa.tasks.TaskMonitor;
39 |
40 | /**
41 | * Abrupt concept drift, boolean noise-free examples. Four relevant attributes,
42 | * two boolean attributes v,w and two numeric attributes from [0; 1]. The
43 | * examples are classified positive if two of three conditions are satisfied:
44 | * v,w, y < 0,5 + 0,3 sin(3 * PI * x). After each context change the
45 | * classification is reversed. Proposed by "Gama, Joao, et al. "Learning with
46 | * drift detection." Advances in artificial intelligence–SBIA 2004. Springer
47 | * Berlin Heidelberg, 2004. 286-295."
48 | *
49 | * @author Paulo Gonçalves (paulogoncalves@recife.ifpe.edu.br)
50 | * @version $Revision: 1 $
51 | */
52 | public class MixedGenerator extends AbstractOptionHandler implements
53 | InstanceStream {
54 |
55 | public IntOption functionOption = new IntOption("function", 'f',
56 | "Classification function used, as defined in the original paper.",
57 | 1, 1, 2);
58 |
59 | public IntOption instanceRandomSeedOption = new IntOption(
60 | "instanceRandomSeed", 'i',
61 | "Seed for random generation of instances.", 1);
62 |
63 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm',
64 | "Percentage of minority class examples", 0.1, 0, 1);
65 |
66 | protected InstancesHeader streamHeader;
67 |
68 | protected Random instanceRandom;
69 |
70 | protected boolean nextClassShouldBeZero;
71 |
72 | protected interface ClassFunction {
73 |
74 | public int determineClass(double v, double w, double x, double y);
75 | }
76 |
77 | protected static ClassFunction[] classificationFunctions = {
78 | new ClassFunction() {
79 | @Override
80 | public int determineClass(double v, double w, double x, double y) {
81 | boolean z = y < 0.5 + 0.3 * Math.sin(3 * Math.PI * x);
82 | if ((v == 1 && w == 1) || (v == 1 && z) || (w == 1 && z)) {
83 | return 0;
84 | } else {
85 | return 1;
86 | }
87 | }
88 | },
89 | new ClassFunction() {
90 | @Override
91 | public int determineClass(double v, double w, double x, double y) {
92 | boolean z = y < 0.5 + 0.3 * Math.sin(3 * Math.PI * x);
93 | if ((v == 1 && w == 1) || (v == 1 && z) || (w == 1 && z)) {
94 | return 1;
95 | } else {
96 | return 0;
97 | }
98 | }
99 | },};
100 |
101 | @Override
102 | public void getDescription(StringBuilder sb, int indent) {
103 |
104 | }
105 |
106 | @Override
107 | public InstancesHeader getHeader() {
108 | return this.streamHeader;
109 | }
110 |
111 | @Override
112 | public long estimatedRemainingInstances() {
113 | return -1;
114 | }
115 |
116 | @Override
117 | public boolean hasMoreInstances() {
118 | return true;
119 | }
120 |
121 | @Override
122 | public InstanceExample nextInstance() {
123 | double v = 0, w = 0, x = 0, y = 0, group = 0;
124 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0;
125 |
126 | do {
127 | v = (this.instanceRandom.nextDouble() < 0.5) ? 0 : 1;
128 | w = (this.instanceRandom.nextDouble() < 0.5) ? 0 : 1;
129 | x = this.instanceRandom.nextDouble();
130 | y = this.instanceRandom.nextDouble();
131 | group = classificationFunctions[this.functionOption.getValue() - 1].determineClass(v, w, x, y);
132 | }while(group != label);
133 |
134 | // construct instance
135 | InstancesHeader header = getHeader();
136 | Instance inst = new DenseInstance(header.numAttributes());
137 | inst.setValue(0, v);
138 | inst.setValue(1, w);
139 | inst.setValue(2, x);
140 | inst.setValue(3, y);
141 | inst.setDataset(header);
142 | inst.setClassValue(group);
143 | return new InstanceExample(inst);
144 | }
145 |
146 | @Override
147 | public boolean isRestartable() {
148 | return true;
149 | }
150 |
151 | @Override
152 | public void restart() {
153 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
154 | this.nextClassShouldBeZero = false;
155 | }
156 |
157 | @Override
158 | protected void prepareForUseImpl(TaskMonitor monitor,
159 | ObjectRepository repository) {
160 | List booleanLabels = new ArrayList();
161 | booleanLabels.add("0");
162 | booleanLabels.add("1");
163 |
164 | ArrayList attributes = new ArrayList();
165 | Attribute attribute1 = new Attribute("v", booleanLabels);
166 | Attribute attribute2 = new Attribute("w", booleanLabels);
167 |
168 | Attribute attribute3 = new Attribute("x");
169 | Attribute attribute4 = new Attribute("y");
170 |
171 | List classLabels = new ArrayList();
172 | classLabels.add("positive");
173 | classLabels.add("negative");
174 | Attribute classAtt = new Attribute("class", classLabels);
175 |
176 | attributes.add(attribute1);
177 | attributes.add(attribute2);
178 | attributes.add(attribute3);
179 | attributes.add(attribute4);
180 | attributes.add(classAtt);
181 |
182 | this.streamHeader = new InstancesHeader(new Instances(
183 | getCLICreationString(InstanceStream.class), attributes, 0));
184 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
185 | restart();
186 | }
187 | }
188 |
--------------------------------------------------------------------------------
/src/main/java/moa/streams/generators/imbalanced/RandomRBFGenerator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * RandomRBFGenerator.java
3 | * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
4 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 | package moa.streams.generators.imbalanced;
21 |
22 | import com.yahoo.labs.samoa.instances.Attribute;
23 | import com.yahoo.labs.samoa.instances.DenseInstance;
24 | import moa.core.FastVector;
25 | import com.yahoo.labs.samoa.instances.Instance;
26 | import com.yahoo.labs.samoa.instances.Instances;
27 |
28 | import java.io.Serializable;
29 | import java.util.Random;
30 | import moa.core.InstanceExample;
31 |
32 | import com.yahoo.labs.samoa.instances.InstancesHeader;
33 | import moa.core.MiscUtils;
34 | import moa.core.ObjectRepository;
35 | import moa.options.AbstractOptionHandler;
36 |
37 | import com.github.javacliparser.FloatOption;
38 | import com.github.javacliparser.IntOption;
39 | import moa.streams.InstanceStream;
40 | import moa.tasks.TaskMonitor;
41 |
42 | /**
43 | * Stream generator for a random radial basis function stream.
44 | *
45 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
46 | * @version $Revision: 7 $
47 | */
48 | public class RandomRBFGenerator extends AbstractOptionHandler implements
49 | InstanceStream {
50 |
51 | @Override
52 | public String getPurposeString() {
53 | return "Generates a random radial basis function stream.";
54 | }
55 |
56 | private static final long serialVersionUID = 1L;
57 |
58 | public IntOption modelRandomSeedOption = new IntOption("modelRandomSeed",
59 | 'r', "Seed for random generation of model.", 1);
60 |
61 | public IntOption instanceRandomSeedOption = new IntOption(
62 | "instanceRandomSeed", 'i',
63 | "Seed for random generation of instances.", 1);
64 |
65 | public IntOption numClassesOption = new IntOption("numClasses", 'c',
66 | "The number of classes to generate.", 2, 2, Integer.MAX_VALUE);
67 |
68 | public IntOption numAttsOption = new IntOption("numAtts", 'a',
69 | "The number of attributes to generate.", 10, 0, Integer.MAX_VALUE);
70 |
71 | public IntOption numCentroidsOption = new IntOption("numCentroids", 'n',
72 | "The number of centroids in the model.", 50, 1, Integer.MAX_VALUE);
73 |
74 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm',
75 | "Percentage of minority class examples", 0.1, 0, 1);
76 |
77 | protected static class Centroid implements Serializable {
78 |
79 | private static final long serialVersionUID = 1L;
80 |
81 | public double[] centre;
82 |
83 | public int classLabel;
84 |
85 | public double stdDev;
86 | }
87 |
88 | protected InstancesHeader streamHeader;
89 |
90 | protected Centroid[] centroids;
91 |
92 | protected double[] centroidWeights;
93 |
94 | protected Random instanceRandom;
95 |
96 | @Override
97 | public void prepareForUseImpl(TaskMonitor monitor,
98 | ObjectRepository repository) {
99 | monitor.setCurrentActivity("Preparing random RBF...", -1.0);
100 | generateHeader();
101 | generateCentroids();
102 | restart();
103 | }
104 |
105 | @Override
106 | public InstancesHeader getHeader() {
107 | return this.streamHeader;
108 | }
109 |
110 | @Override
111 | public long estimatedRemainingInstances() {
112 | return -1;
113 | }
114 |
115 | @Override
116 | public boolean hasMoreInstances() {
117 | return true;
118 | }
119 |
120 | @Override
121 | public boolean isRestartable() {
122 | return true;
123 | }
124 |
125 | @Override
126 | public void restart() {
127 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
128 | }
129 |
130 | @Override
131 | public InstanceExample nextInstance() {
132 | Centroid centroid;
133 |
134 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0;
135 |
136 | do{
137 | centroid = this.centroids[MiscUtils.chooseRandomIndexBasedOnWeights(this.centroidWeights, this.instanceRandom)];
138 | }while(centroid.classLabel != label);
139 |
140 | int numAtts = this.numAttsOption.getValue();
141 | double[] attVals = new double[numAtts + 1];
142 | for (int i = 0; i < numAtts; i++) {
143 | attVals[i] = (this.instanceRandom.nextDouble() * 2.0) - 1.0;
144 | }
145 | double magnitude = 0.0;
146 | for (int i = 0; i < numAtts; i++) {
147 | magnitude += attVals[i] * attVals[i];
148 | }
149 | magnitude = Math.sqrt(magnitude);
150 | double desiredMag = this.instanceRandom.nextGaussian() * centroid.stdDev;
151 | double scale = desiredMag / magnitude;
152 | for (int i = 0; i < numAtts; i++) {
153 | attVals[i] = centroid.centre[i] + attVals[i] * scale;
154 | }
155 | Instance inst = new DenseInstance(1.0, attVals);
156 | inst.setDataset(getHeader());
157 | inst.setClassValue(centroid.classLabel);
158 | return new InstanceExample(inst);
159 | }
160 |
161 | protected void generateHeader() {
162 | FastVector attributes = new FastVector();
163 | for (int i = 0; i < this.numAttsOption.getValue(); i++) {
164 | attributes.addElement(new Attribute("att" + (i + 1)));
165 | }
166 | FastVector classLabels = new FastVector();
167 | for (int i = 0; i < this.numClassesOption.getValue(); i++) {
168 | classLabels.addElement("class" + (i + 1));
169 | }
170 | attributes.addElement(new Attribute("class", classLabels));
171 | this.streamHeader = new InstancesHeader(new Instances(
172 | getCLICreationString(InstanceStream.class), attributes, 0));
173 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
174 | }
175 |
176 | protected void generateCentroids() {
177 | Random modelRand = new Random(this.modelRandomSeedOption.getValue());
178 | this.centroids = new Centroid[this.numCentroidsOption.getValue()];
179 | this.centroidWeights = new double[this.centroids.length];
180 | for (int i = 0; i < this.centroids.length; i++) {
181 | this.centroids[i] = new Centroid();
182 | double[] randCentre = new double[this.numAttsOption.getValue()];
183 | for (int j = 0; j < randCentre.length; j++) {
184 | randCentre[j] = modelRand.nextDouble();
185 | }
186 | this.centroids[i].centre = randCentre;
187 | this.centroids[i].classLabel = modelRand.nextInt(this.numClassesOption.getValue());
188 | this.centroids[i].stdDev = modelRand.nextDouble();
189 | this.centroidWeights[i] = modelRand.nextDouble();
190 | }
191 | }
192 |
193 | @Override
194 | public void getDescription(StringBuilder sb, int indent) {
195 | // TODO Auto-generated method stub
196 | }
197 | }
198 |
--------------------------------------------------------------------------------
/src/main/java/moa/streams/generators/imbalanced/RandomRBFGeneratorDrift.java:
--------------------------------------------------------------------------------
1 | /*
2 | * RandomRBFGeneratorDrift.java
3 | * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
4 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 | package moa.streams.generators.imbalanced;
21 |
22 | import java.util.Random;
23 | import moa.core.InstanceExample;
24 |
25 | import com.github.javacliparser.IntOption;
26 | import com.github.javacliparser.FloatOption;
27 | import com.yahoo.labs.samoa.instances.Instance;
28 |
29 | /**
30 | * Stream generator for a random radial basis function stream with drift.
31 | *
32 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
33 | * @version $Revision: 7 $
34 | */
35 | public class RandomRBFGeneratorDrift extends RandomRBFGenerator {
36 |
37 | @Override
38 | public String getPurposeString() {
39 | return "Generates a random radial basis function stream with drift.";
40 | }
41 |
42 | private static final long serialVersionUID = 1L;
43 |
44 | public FloatOption speedChangeOption = new FloatOption("speedChange", 's',
45 | "Speed of change of centroids in the model.", 0, 0, Float.MAX_VALUE);
46 |
47 | public IntOption numDriftCentroidsOption = new IntOption("numDriftCentroids", 'k',
48 | "The number of centroids with drift.", 50, 0, Integer.MAX_VALUE);
49 |
50 | protected double[][] speedCentroids;
51 |
52 | @Override
53 | public InstanceExample nextInstance() {
54 | //Update Centroids with drift
55 | int len = this.numDriftCentroidsOption.getValue();
56 | if (len > this.centroids.length) {
57 | len = this.centroids.length;
58 | }
59 | for (int j = 0; j < len; j++) {
60 | for (int i = 0; i < this.numAttsOption.getValue(); i++) {
61 | this.centroids[j].centre[i] += this.speedCentroids[j][i] * this.speedChangeOption.getValue();
62 | if (this.centroids[j].centre[i] > 1) {
63 | this.centroids[j].centre[i] = 1;
64 | this.speedCentroids[j][i] = -this.speedCentroids[j][i];
65 | }
66 | if (this.centroids[j].centre[i] < 0) {
67 | this.centroids[j].centre[i] = 0;
68 | this.speedCentroids[j][i] = -this.speedCentroids[j][i];
69 | }
70 | }
71 | }
72 | return super.nextInstance();
73 | }
74 |
75 | @Override
76 | protected void generateCentroids() {
77 | super.generateCentroids();
78 | Random modelRand = new Random(this.modelRandomSeedOption.getValue());
79 | int len = this.numDriftCentroidsOption.getValue();
80 | if (len > this.centroids.length) {
81 | len = this.centroids.length;
82 | }
83 | this.speedCentroids = new double[len][this.numAttsOption.getValue()];
84 | for (int i = 0; i < len; i++) {
85 | double[] randSpeed = new double[this.numAttsOption.getValue()];
86 | double normSpeed = 0.0;
87 | for (int j = 0; j < randSpeed.length; j++) {
88 | randSpeed[j] = modelRand.nextDouble();
89 | normSpeed += randSpeed[j] * randSpeed[j];
90 | }
91 | normSpeed = Math.sqrt(normSpeed);
92 | for (int j = 0; j < randSpeed.length; j++) {
93 | randSpeed[j] /= normSpeed;
94 | }
95 | this.speedCentroids[i] = randSpeed;
96 | }
97 | }
98 |
99 | @Override
100 | public void getDescription(StringBuilder sb, int indent) {
101 | // TODO Auto-generated method stub
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/src/main/java/moa/streams/generators/imbalanced/RandomTreeGenerator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * RandomTreeGenerator.java
3 | * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
4 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 | package moa.streams.generators.imbalanced;
21 |
22 | import com.yahoo.labs.samoa.instances.Attribute;
23 | import com.yahoo.labs.samoa.instances.DenseInstance;
24 | import moa.core.FastVector;
25 | import com.yahoo.labs.samoa.instances.Instance;
26 | import com.yahoo.labs.samoa.instances.Instances;
27 |
28 | import java.io.Serializable;
29 | import java.util.ArrayList;
30 | import java.util.Random;
31 | import moa.core.InstanceExample;
32 |
33 | import com.yahoo.labs.samoa.instances.InstancesHeader;
34 | import moa.core.ObjectRepository;
35 | import moa.options.AbstractOptionHandler;
36 | import com.github.javacliparser.FloatOption;
37 | import com.github.javacliparser.IntOption;
38 | import moa.streams.InstanceStream;
39 | import moa.tasks.TaskMonitor;
40 |
41 | /**
42 | * Stream generator for a stream based on a randomly generated tree..
43 | *
44 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
45 | * @version $Revision: 7 $
46 | */
47 | public class RandomTreeGenerator extends AbstractOptionHandler implements
48 | InstanceStream {
49 |
50 | @Override
51 | public String getPurposeString() {
52 | return "Generates a stream based on a randomly generated tree.";
53 | }
54 |
55 | private static final long serialVersionUID = 1L;
56 |
57 | public IntOption treeRandomSeedOption = new IntOption("treeRandomSeed",
58 | 'r', "Seed for random generation of tree.", 1);
59 |
60 | public IntOption instanceRandomSeedOption = new IntOption(
61 | "instanceRandomSeed", 'i',
62 | "Seed for random generation of instances.", 1);
63 |
64 | public IntOption numClassesOption = new IntOption("numClasses", 'c',
65 | "The number of classes to generate.", 2, 2, Integer.MAX_VALUE);
66 |
67 | public IntOption numNominalsOption = new IntOption("numNominals", 'o',
68 | "The number of nominal attributes to generate.", 5, 0,
69 | Integer.MAX_VALUE);
70 |
71 | public IntOption numNumericsOption = new IntOption("numNumerics", 'u',
72 | "The number of numeric attributes to generate.", 5, 0,
73 | Integer.MAX_VALUE);
74 |
75 | public IntOption numValsPerNominalOption = new IntOption(
76 | "numValsPerNominal", 'v',
77 | "The number of values to generate per nominal attribute.", 5, 2,
78 | Integer.MAX_VALUE);
79 |
80 | public IntOption maxTreeDepthOption = new IntOption("maxTreeDepth", 'd',
81 | "The maximum depth of the tree concept.", 5, 0, Integer.MAX_VALUE);
82 |
83 | public IntOption firstLeafLevelOption = new IntOption(
84 | "firstLeafLevel",
85 | 'l',
86 | "The first level of the tree above maxTreeDepth that can have leaves.",
87 | 3, 0, Integer.MAX_VALUE);
88 |
89 | public FloatOption leafFractionOption = new FloatOption("leafFraction",
90 | 'f',
91 | "The fraction of leaves per level from firstLeafLevel onwards.",
92 | 0.15, 0.0, 1.0);
93 |
94 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm',
95 | "Percentage of minority class examples", 0.1, 0, 1);
96 |
97 | protected static class Node implements Serializable {
98 |
99 | private static final long serialVersionUID = 1L;
100 |
101 | public int classLabel;
102 |
103 | public int splitAttIndex;
104 |
105 | public double splitAttValue;
106 |
107 | public Node[] children;
108 | }
109 |
110 | protected Node treeRoot;
111 |
112 | protected InstancesHeader streamHeader;
113 |
114 | protected Random instanceRandom;
115 |
116 | @Override
117 | public void prepareForUseImpl(TaskMonitor monitor,
118 | ObjectRepository repository) {
119 | monitor.setCurrentActivity("Preparing random tree...", -1.0);
120 | generateHeader();
121 | generateRandomTree();
122 | restart();
123 | }
124 |
125 | @Override
126 | public long estimatedRemainingInstances() {
127 | return -1;
128 | }
129 |
130 | @Override
131 | public boolean isRestartable() {
132 | return true;
133 | }
134 |
135 | @Override
136 | public void restart() {
137 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
138 | }
139 |
140 | @Override
141 | public InstancesHeader getHeader() {
142 | return this.streamHeader;
143 | }
144 |
145 | @Override
146 | public boolean hasMoreInstances() {
147 | return true;
148 | }
149 |
150 | @Override
151 | public InstanceExample nextInstance() {
152 | double[] attVals = new double[this.numNominalsOption.getValue() + this.numNumericsOption.getValue()];
153 | InstancesHeader header = getHeader();
154 | Instance inst = new DenseInstance(header.numAttributes());
155 | inst.setDataset(header);
156 |
157 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0;
158 |
159 | do{
160 | for (int i = 0; i < attVals.length; i++) {
161 | attVals[i] = i < this.numNominalsOption.getValue() ? this.instanceRandom.nextInt(this.numValsPerNominalOption.getValue()) : this.instanceRandom.nextDouble();
162 | inst.setValue(i, attVals[i]);
163 | }
164 | inst.setClassValue(classifyInstance(this.treeRoot, attVals));
165 | }while(label != inst.classValue());
166 |
167 | return new InstanceExample(inst);
168 | }
169 |
170 | protected int classifyInstance(Node node, double[] attVals) {
171 | if (node.children == null) {
172 | return node.classLabel;
173 | }
174 | if (node.splitAttIndex < this.numNominalsOption.getValue()) {
175 | return classifyInstance(
176 | node.children[(int) attVals[node.splitAttIndex]], attVals);
177 | }
178 | return classifyInstance(
179 | node.children[attVals[node.splitAttIndex] < node.splitAttValue ? 0
180 | : 1], attVals);
181 | }
182 |
183 | protected void generateHeader() {
184 | FastVector attributes = new FastVector();
185 | FastVector nominalAttVals = new FastVector();
186 | for (int i = 0; i < this.numValsPerNominalOption.getValue(); i++) {
187 | nominalAttVals.addElement("value" + (i + 1));
188 | }
189 | for (int i = 0; i < this.numNominalsOption.getValue(); i++) {
190 | attributes.addElement(new Attribute("nominal" + (i + 1),
191 | nominalAttVals));
192 | }
193 | for (int i = 0; i < this.numNumericsOption.getValue(); i++) {
194 | attributes.addElement(new Attribute("numeric" + (i + 1)));
195 | }
196 | FastVector classLabels = new FastVector();
197 | for (int i = 0; i < this.numClassesOption.getValue(); i++) {
198 | classLabels.addElement("class" + (i + 1));
199 | }
200 | attributes.addElement(new Attribute("class", classLabels));
201 | this.streamHeader = new InstancesHeader(new Instances(
202 | getCLICreationString(InstanceStream.class), attributes, 0));
203 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
204 | }
205 |
206 | protected void generateRandomTree() {
207 | Random treeRand = new Random(this.treeRandomSeedOption.getValue());
208 | ArrayList nominalAttCandidates = new ArrayList(
209 | this.numNominalsOption.getValue());
210 | for (int i = 0; i < this.numNominalsOption.getValue(); i++) {
211 | nominalAttCandidates.add(i);
212 | }
213 | double[] minNumericVals = new double[this.numNumericsOption.getValue()];
214 | double[] maxNumericVals = new double[this.numNumericsOption.getValue()];
215 | for (int i = 0; i < this.numNumericsOption.getValue(); i++) {
216 | minNumericVals[i] = 0.0;
217 | maxNumericVals[i] = 1.0;
218 | }
219 | this.treeRoot = generateRandomTreeNode(0, nominalAttCandidates,
220 | minNumericVals, maxNumericVals, treeRand);
221 | }
222 |
223 | protected Node generateRandomTreeNode(int currentDepth,
224 | ArrayList nominalAttCandidates, double[] minNumericVals,
225 | double[] maxNumericVals, Random treeRand) {
226 | if ((currentDepth >= this.maxTreeDepthOption.getValue())
227 | || ((currentDepth >= this.firstLeafLevelOption.getValue()) && (this.leafFractionOption.getValue() >= (1.0 - treeRand.nextDouble())))) {
228 | Node leaf = new Node();
229 | leaf.classLabel = treeRand.nextInt(this.numClassesOption.getValue());
230 | return leaf;
231 | }
232 | Node node = new Node();
233 | int chosenAtt = treeRand.nextInt(nominalAttCandidates.size()
234 | + this.numNumericsOption.getValue());
235 | if (chosenAtt < nominalAttCandidates.size()) {
236 | node.splitAttIndex = nominalAttCandidates.get(chosenAtt);
237 | node.children = new Node[this.numValsPerNominalOption.getValue()];
238 | ArrayList newNominalCandidates = new ArrayList(
239 | nominalAttCandidates);
240 | newNominalCandidates.remove(new Integer(node.splitAttIndex));
241 | newNominalCandidates.trimToSize();
242 | for (int i = 0; i < node.children.length; i++) {
243 | node.children[i] = generateRandomTreeNode(currentDepth + 1,
244 | newNominalCandidates, minNumericVals, maxNumericVals,
245 | treeRand);
246 | }
247 | } else {
248 | int numericIndex = chosenAtt - nominalAttCandidates.size();
249 | node.splitAttIndex = this.numNominalsOption.getValue()
250 | + numericIndex;
251 | double minVal = minNumericVals[numericIndex];
252 | double maxVal = maxNumericVals[numericIndex];
253 | node.splitAttValue = ((maxVal - minVal) * treeRand.nextDouble())
254 | + minVal;
255 | node.children = new Node[2];
256 | double[] newMaxVals = maxNumericVals.clone();
257 | newMaxVals[numericIndex] = node.splitAttValue;
258 | node.children[0] = generateRandomTreeNode(currentDepth + 1,
259 | nominalAttCandidates, minNumericVals, newMaxVals, treeRand);
260 | double[] newMinVals = minNumericVals.clone();
261 | newMinVals[numericIndex] = node.splitAttValue;
262 | node.children[1] = generateRandomTreeNode(currentDepth + 1,
263 | nominalAttCandidates, newMinVals, maxNumericVals, treeRand);
264 | }
265 | return node;
266 | }
267 |
268 | @Override
269 | public void getDescription(StringBuilder sb, int indent) {
270 | // TODO Auto-generated method stub
271 | }
272 | }
273 |
--------------------------------------------------------------------------------
/src/main/java/moa/streams/generators/imbalanced/SEAGenerator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * SEAGenerator.java
3 | * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
4 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 | package moa.streams.generators.imbalanced;
21 |
22 | import com.yahoo.labs.samoa.instances.Attribute;
23 | import com.yahoo.labs.samoa.instances.DenseInstance;
24 | import moa.core.FastVector;
25 | import com.yahoo.labs.samoa.instances.Instance;
26 | import com.yahoo.labs.samoa.instances.Instances;
27 |
28 | import java.util.Random;
29 | import moa.core.InstanceExample;
30 |
31 | import com.yahoo.labs.samoa.instances.InstancesHeader;
32 | import moa.core.ObjectRepository;
33 | import moa.options.AbstractOptionHandler;
34 | import com.github.javacliparser.FlagOption;
35 | import com.github.javacliparser.FloatOption;
36 | import com.github.javacliparser.IntOption;
37 | import moa.streams.InstanceStream;
38 | import moa.tasks.TaskMonitor;
39 |
40 | /**
41 | * Stream generator for SEA concepts functions.
42 | * Generator described in the paper:
43 | * W. Nick Street and YongSeog Kim
44 | * "A streaming ensemble algorithm (SEA) for large-scale classification",
45 | * KDD '01: Proceedings of the seventh ACM SIGKDD international conference on Knowledge discovery and data mining
46 | * 377-382 2001.
47 | *
48 | * Notes:
49 | * The built in functions are based on the paper.
50 | *
51 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
52 | * @version $Revision: 7 $
53 | */
54 | public class SEAGenerator extends AbstractOptionHandler implements
55 | InstanceStream {
56 |
57 | @Override
58 | public String getPurposeString() {
59 | return "Generates SEA concepts functions.";
60 | }
61 |
62 | private static final long serialVersionUID = 1L;
63 |
64 | public IntOption functionOption = new IntOption("function", 'f',
65 | "Classification function used, as defined in the original paper.",
66 | 1, 1, 4);
67 |
68 | public IntOption instanceRandomSeedOption = new IntOption(
69 | "instanceRandomSeed", 'i',
70 | "Seed for random generation of instances.", 1);
71 |
72 | public IntOption numInstancesConcept = new IntOption("numInstancesConcept", 'n',
73 | "The number of instances for each concept.", 0, 0, Integer.MAX_VALUE);
74 |
75 | public IntOption noisePercentageOption = new IntOption("noisePercentage",
76 | 'p', "Percentage of noise to add to the data.", 10, 0, 100);
77 |
78 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm',
79 | "Percentage of minority class examples", 0.1, 0, 1);
80 |
81 | protected interface ClassFunction {
82 |
83 | public int determineClass(double attrib1, double attrib2, double attrib3);
84 | }
85 |
86 | protected static ClassFunction[] classificationFunctions = {
87 | // function 1
88 | new ClassFunction() {
89 |
90 | @Override
91 | public int determineClass(double attrib1, double attrib2, double attrib3) {
92 | return (attrib1 + attrib2 <= 8) ? 0 : 1;
93 | }
94 | },
95 | // function 2
96 | new ClassFunction() {
97 |
98 | @Override
99 | public int determineClass(double attrib1, double attrib2, double attrib3) {
100 | return (attrib1 + attrib2 <= 9) ? 0 : 1;
101 | }
102 | },
103 | // function 3
104 | new ClassFunction() {
105 |
106 | public int determineClass(double attrib1, double attrib2, double attrib3) {
107 | return (attrib1 + attrib2 <= 7) ? 0 : 1;
108 | }
109 | },
110 | // function 4
111 | new ClassFunction() {
112 |
113 | @Override
114 | public int determineClass(double attrib1, double attrib2, double attrib3) {
115 | return (attrib1 + attrib2 <= 9.5) ? 0 : 1;
116 | }
117 | }
118 | };
119 |
120 | protected InstancesHeader streamHeader;
121 |
122 | protected Random instanceRandom;
123 |
124 | protected boolean nextClassShouldBeZero;
125 |
126 | @Override
127 | protected void prepareForUseImpl(TaskMonitor monitor,
128 | ObjectRepository repository) {
129 | // generate header
130 | FastVector attributes = new FastVector();
131 | attributes.addElement(new Attribute("attrib1"));
132 | attributes.addElement(new Attribute("attrib2"));
133 | attributes.addElement(new Attribute("attrib3"));
134 |
135 | FastVector classLabels = new FastVector();
136 | classLabels.addElement("groupA");
137 | classLabels.addElement("groupB");
138 | attributes.addElement(new Attribute("class", classLabels));
139 | this.streamHeader = new InstancesHeader(new Instances(
140 | getCLICreationString(InstanceStream.class), attributes, 0));
141 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
142 | restart();
143 | }
144 |
145 | @Override
146 | public long estimatedRemainingInstances() {
147 | return -1;
148 | }
149 |
150 | @Override
151 | public InstancesHeader getHeader() {
152 | return this.streamHeader;
153 | }
154 |
155 | @Override
156 | public boolean hasMoreInstances() {
157 | return true;
158 | }
159 |
160 | @Override
161 | public boolean isRestartable() {
162 | return true;
163 | }
164 |
165 | @Override
166 | public InstanceExample nextInstance() {
167 | double attrib1 = 0, attrib2 = 0, attrib3 = 0;
168 | int group = 0;
169 |
170 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0;
171 |
172 | do {
173 | // generate attributes
174 | attrib1 = 10 * this.instanceRandom.nextDouble();
175 | attrib2 = 10 * this.instanceRandom.nextDouble();
176 | attrib3 = 10 * this.instanceRandom.nextDouble();
177 |
178 | // determine class
179 | group = classificationFunctions[this.functionOption.getValue() - 1].determineClass(attrib1, attrib2, attrib3);
180 | }while(group != label);
181 |
182 | //Add Noise
183 | if ((1 + (this.instanceRandom.nextInt(100))) <= this.noisePercentageOption.getValue()) {
184 | group = (group == 0 ? 1 : 0);
185 | }
186 |
187 | // construct instance
188 | InstancesHeader header = getHeader();
189 | Instance inst = new DenseInstance(header.numAttributes());
190 | inst.setValue(0, attrib1);
191 | inst.setValue(1, attrib2);
192 | inst.setValue(2, attrib3);
193 | inst.setDataset(header);
194 | inst.setClassValue(group);
195 | return new InstanceExample(inst);
196 | }
197 |
198 | @Override
199 | public void restart() {
200 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
201 | this.nextClassShouldBeZero = false;
202 | }
203 |
204 | @Override
205 | public void getDescription(StringBuilder sb, int indent) {
206 | // TODO Auto-generated method stub
207 | }
208 | }
209 |
--------------------------------------------------------------------------------
/src/main/java/moa/streams/generators/imbalanced/STAGGERGenerator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * STAGGERGenerator.java
3 | * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
4 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 | package moa.streams.generators.imbalanced;
21 |
22 | import com.yahoo.labs.samoa.instances.Attribute;
23 | import com.yahoo.labs.samoa.instances.DenseInstance;
24 | import moa.core.FastVector;
25 | import com.yahoo.labs.samoa.instances.Instance;
26 | import com.yahoo.labs.samoa.instances.Instances;
27 |
28 | import java.util.Random;
29 | import moa.core.InstanceExample;
30 |
31 | import com.yahoo.labs.samoa.instances.InstancesHeader;
32 | import moa.core.ObjectRepository;
33 | import moa.options.AbstractOptionHandler;
34 | import com.github.javacliparser.FlagOption;
35 | import com.github.javacliparser.FloatOption;
36 | import com.github.javacliparser.IntOption;
37 | import moa.streams.InstanceStream;
38 | import moa.tasks.TaskMonitor;
39 |
40 | /**
41 | * Stream generator for STAGGER Concept functions.
42 | *
43 | * Generator described in the paper:
44 | * Jeffrey C. Schlimmer and Richard H. Granger Jr.
45 | * "Incremental Learning from Noisy Data",
46 | * Machine Learning 1: 317-354 1986.
47 | *
48 | * Notes:
49 | * The built in functions are based on the paper (page 341).
50 | *
51 | * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
52 | * @version $Revision: 7 $
53 | */
54 | public class STAGGERGenerator extends AbstractOptionHandler implements
55 | InstanceStream {
56 |
57 | @Override
58 | public String getPurposeString() {
59 | return "Generates STAGGER Concept functions.";
60 | }
61 |
62 | private static final long serialVersionUID = 1L;
63 |
64 | public IntOption instanceRandomSeedOption = new IntOption(
65 | "instanceRandomSeed", 'i',
66 | "Seed for random generation of instances.", 1);
67 |
68 | public IntOption functionOption = new IntOption("function", 'f',
69 | "Classification function used, as defined in the original paper.",
70 | 1, 1, 3);
71 |
72 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm',
73 | "Percentage of minority class examples", 0.1, 0, 1);
74 |
75 | protected interface ClassFunction {
76 |
77 | public int determineClass(int size, int color, int shape);
78 | }
79 |
80 | protected static ClassFunction[] classificationFunctions = {
81 | // function 1
82 | new ClassFunction() {
83 |
84 | @Override
85 | public int determineClass(int size, int color, int shape) {
86 | return (size == 0 && color == 0) ? 1 : 0; //size==small && color==red
87 | }
88 | },
89 | // function 2
90 | new ClassFunction() {
91 |
92 | @Override
93 | public int determineClass(int size, int color, int shape) {
94 | return (color == 2 || shape == 0) ? 1 : 0; //color==green || shape==circle
95 | }
96 | },
97 | // function 3
98 | new ClassFunction() {
99 |
100 | @Override
101 | public int determineClass(int size, int color, int shape) {
102 | return (size == 1 || size == 2) ? 1 : 0; // size==medium || size==large
103 | }
104 | }
105 | };
106 |
107 | protected InstancesHeader streamHeader;
108 |
109 | protected Random instanceRandom;
110 |
111 | protected boolean nextClassShouldBeZero;
112 |
113 | @Override
114 | protected void prepareForUseImpl(TaskMonitor monitor,
115 | ObjectRepository repository) {
116 | // generate header
117 | FastVector attributes = new FastVector();
118 |
119 | FastVector sizeLabels = new FastVector();
120 | sizeLabels.addElement("small");
121 | sizeLabels.addElement("medium");
122 | sizeLabels.addElement("large");
123 | attributes.addElement(new Attribute("size", sizeLabels));
124 |
125 | FastVector colorLabels = new FastVector();
126 | colorLabels.addElement("red");
127 | colorLabels.addElement("blue");
128 | colorLabels.addElement("green");
129 | attributes.addElement(new Attribute("color", colorLabels));
130 |
131 | FastVector shapeLabels = new FastVector();
132 | shapeLabels.addElement("circle");
133 | shapeLabels.addElement("square");
134 | shapeLabels.addElement("triangle");
135 | attributes.addElement(new Attribute("shape", shapeLabels));
136 |
137 | FastVector classLabels = new FastVector();
138 | classLabels.addElement("false");
139 | classLabels.addElement("true");
140 | attributes.addElement(new Attribute("class", classLabels));
141 | this.streamHeader = new InstancesHeader(new Instances(
142 | getCLICreationString(InstanceStream.class), attributes, 0));
143 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
144 | restart();
145 | }
146 |
147 | @Override
148 | public long estimatedRemainingInstances() {
149 | return -1;
150 | }
151 |
152 | @Override
153 | public InstancesHeader getHeader() {
154 | return this.streamHeader;
155 | }
156 |
157 | @Override
158 | public boolean hasMoreInstances() {
159 | return true;
160 | }
161 |
162 | @Override
163 | public boolean isRestartable() {
164 | return true;
165 | }
166 |
167 | @Override
168 | public InstanceExample nextInstance() {
169 |
170 | int size = 0, color = 0, shape = 0, group = 0;
171 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0;
172 |
173 | do {
174 | // generate attributes
175 | size = this.instanceRandom.nextInt(3);
176 | color = this.instanceRandom.nextInt(3);
177 | shape = this.instanceRandom.nextInt(3);
178 |
179 | // determine class
180 | group = classificationFunctions[this.functionOption.getValue() - 1].determineClass(size, color, shape);
181 | }while(group != label);
182 |
183 | // construct instance
184 | InstancesHeader header = getHeader();
185 | Instance inst = new DenseInstance(header.numAttributes());
186 | inst.setValue(0, size);
187 | inst.setValue(1, color);
188 | inst.setValue(2, shape);
189 | inst.setDataset(header);
190 | inst.setClassValue(group);
191 | return new InstanceExample(inst);
192 | }
193 |
194 | @Override
195 | public void restart() {
196 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
197 | this.nextClassShouldBeZero = false;
198 | }
199 |
200 | @Override
201 | public void getDescription(StringBuilder sb, int indent) {
202 | // TODO Auto-generated method stub
203 | }
204 | }
205 |
--------------------------------------------------------------------------------
/src/main/java/moa/streams/generators/imbalanced/SineGenerator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * SineGenerator.java
3 | * Copyright (C) 2016 Instituto Federal de Pernambuco
4 | * @author Paulo Gonçalves (paulogoncalves@recife.ifpe.edu.br)
5 | *
6 | * This program is free software; you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation; either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * This program is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | *
19 | */
20 | package moa.streams.generators.imbalanced;
21 |
22 | import com.github.javacliparser.FlagOption;
23 | import com.github.javacliparser.FloatOption;
24 | import com.github.javacliparser.IntOption;
25 | import com.yahoo.labs.samoa.instances.Attribute;
26 | import com.yahoo.labs.samoa.instances.DenseInstance;
27 | import com.yahoo.labs.samoa.instances.Instance;
28 | import com.yahoo.labs.samoa.instances.Instances;
29 | import com.yahoo.labs.samoa.instances.InstancesHeader;
30 | import java.util.ArrayList;
31 | import java.util.List;
32 | import java.util.Random;
33 | import moa.core.InstanceExample;
34 |
35 | import moa.core.ObjectRepository;
36 | import moa.options.AbstractOptionHandler;
37 | import moa.streams.InstanceStream;
38 | import moa.tasks.TaskMonitor;
39 |
40 | /**
41 | * 1.SINE1. Abrupt concept drift, noise-free examples. It has two relevant
42 | * attributes. Each attributes has values uniformly distributed in [0; 1]. In
43 | * the first context all points below the curve y = sin(x) are classified as
44 | * positive. After the context change the classification is reversed.
45 | * 2.SINE2. The same two relevant attributes. The classification function is
46 | * y < 0.5 + 0.3 sin(3 * PI * x). After the context change the classification
47 | * is reversed.
48 | * 3.SINIRREL1. Presence of irrelevant attributes. The same classification
49 | * function of SINE1 but the examples have two more random attributes
50 | * with no influence on the classification function.
51 | * 4.SINIRREL2. The same classification function of SINE2 but the examples
52 | * have two more random attributes with no influence on the classification
53 | * function.
54 | * Based on proposal by "Gama, Joao, et al. "Learning with drift
55 | * detection." Advances in artificial intelligence–SBIA 2004. Springer Berlin
56 | * Heidelberg, 2004. 286-295."
57 | *
58 | * @author Paulo Gonçalves (paulogoncalves@recife.ifpe.edu.br)
59 | * @version $Revision: 1 $
60 | */
61 | public class SineGenerator extends AbstractOptionHandler implements
62 | InstanceStream {
63 |
64 | public static final int NUM_IRRELEVANT_ATTRIBUTES = 2;
65 |
66 | public IntOption instanceRandomSeedOption = new IntOption(
67 | "instanceRandomSeed", 'i',
68 | "Seed for random generation of instances.", 1);
69 |
70 | public IntOption functionOption = new IntOption("function", 'f',
71 | "Classification function used, as defined in the original paper.",
72 | 1, 1, 4);
73 |
74 | public FlagOption suppressIrrelevantAttributesOption = new FlagOption(
75 | "suppressIrrelevantAttributes", 's',
76 | "Reduce the data to only contain 2 relevant numeric attributes.");
77 |
78 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm',
79 | "Percentage of minority class examples", 0.1, 0, 1);
80 |
81 | protected InstancesHeader streamHeader;
82 |
83 | protected Random instanceRandom;
84 |
85 | protected boolean nextClassShouldBeZero;
86 |
87 | protected interface ClassFunction {
88 |
89 | public int determineClass(double x, double y);
90 | }
91 |
92 | protected static ClassFunction[] classificationFunctions = {
93 | // Values below the curve y = sin(x) are classified as positive.
94 | new ClassFunction() {
95 |
96 | @Override
97 | public int determineClass(double x, double y) {
98 | return (y < Math.sin(x)) ? 0 : 1;
99 | }
100 | },
101 | // Values below the curve y = sin(x) are classified as negative.
102 | new ClassFunction() {
103 |
104 | @Override
105 | public int determineClass(double x, double y) {
106 | return (y >= Math.sin(x)) ? 0 : 1;
107 | }
108 | },
109 | // Values below the curve y = 0.5 + 0.3*sin(3*PI*x) are classified as positive.
110 | new ClassFunction() {
111 |
112 | @Override
113 | public int determineClass(double x, double y) {
114 | return (y < 0.5 + 0.3 * Math.sin(3 * Math.PI * x)) ? 0 : 1;
115 | }
116 | },
117 | // Values below the curve y = 0.5 + 0.3*sin(3*PI*x) are classified as negative.
118 | new ClassFunction() {
119 |
120 | @Override
121 | public int determineClass(double x, double y) {
122 | return (y >= 0.5 + 0.3 * Math.sin(3 * Math.PI * x)) ? 0 : 1;
123 | }
124 | },};
125 |
126 | @Override
127 | public void getDescription(StringBuilder sb, int indent) {
128 |
129 | }
130 |
131 | @Override
132 | public InstancesHeader getHeader() {
133 | return this.streamHeader;
134 | }
135 |
136 | @Override
137 | public long estimatedRemainingInstances() {
138 | return -1;
139 | }
140 |
141 | @Override
142 | public boolean hasMoreInstances() {
143 | return true;
144 | }
145 |
146 | @Override
147 | public InstanceExample nextInstance() {
148 | double a1 = 0, a2 = 0, group = 0;
149 |
150 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0;
151 |
152 | do {
153 | a1 = this.instanceRandom.nextDouble();
154 | a2 = this.instanceRandom.nextDouble();
155 | group = classificationFunctions[this.functionOption.getValue() - 1].determineClass(a1, a2);
156 | }while(group != label);
157 |
158 | // construct instance
159 | InstancesHeader header = getHeader();
160 | Instance inst = new DenseInstance(header.numAttributes());
161 | inst.setValue(0, a1);
162 | inst.setValue(1, a2);
163 | inst.setDataset(header);
164 | if (!this.suppressIrrelevantAttributesOption.isSet()) {
165 | for (int i = 0; i < NUM_IRRELEVANT_ATTRIBUTES; i++) {
166 | inst.setValue(i + 2, this.instanceRandom.nextDouble());
167 | }
168 | }
169 | inst.setClassValue(group);
170 | return new InstanceExample(inst);
171 | }
172 |
173 | @Override
174 | public boolean isRestartable() {
175 | return true;
176 | }
177 |
178 | @Override
179 | public void restart() {
180 | this.instanceRandom = new Random(
181 | this.instanceRandomSeedOption.getValue());
182 | this.nextClassShouldBeZero = false;
183 | }
184 |
185 | @Override
186 | protected void prepareForUseImpl(TaskMonitor monitor,
187 | ObjectRepository repository) {
188 | ArrayList attributes = new ArrayList();
189 |
190 | int numAtts = 2;
191 | if (!this.suppressIrrelevantAttributesOption.isSet()) {
192 | numAtts += NUM_IRRELEVANT_ATTRIBUTES;
193 | }
194 | for (int i = 0; i < numAtts; i++) {
195 | attributes.add(new Attribute("att" + (i + 1)));
196 | }
197 |
198 | List classLabels = new ArrayList();
199 | classLabels.add("positive");
200 | classLabels.add("negative");
201 | Attribute classAtt = new Attribute("class", classLabels);
202 | attributes.add(classAtt);
203 |
204 | this.streamHeader = new InstancesHeader(new Instances(
205 | getCLICreationString(InstanceStream.class), attributes, 0));
206 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
207 | restart();
208 | }
209 | }
210 |
--------------------------------------------------------------------------------
/src/main/java/moa/streams/generators/imbalanced/TextGenerator.java:
--------------------------------------------------------------------------------
1 | package moa.streams.generators.imbalanced;
2 |
3 | import com.github.javacliparser.FloatOption;
4 |
5 | /*
6 | *
7 | * Licensed under the Apache License, Version 2.0 (the "License");
8 | * you may not use this file except in compliance with the License.
9 | * You may obtain a copy of the License at
10 | *
11 | * http://www.apache.org/licenses/LICENSE-2.0
12 | *
13 | * Unless required by applicable law or agreed to in writing, software
14 | * distributed under the License is distributed on an "AS IS" BASIS,
15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | * See the License for the specific language governing permissions and
17 | * limitations under the License.
18 | *
19 | */
20 |
21 | import com.github.javacliparser.IntOption;
22 | import com.yahoo.labs.samoa.instances.*;
23 | import moa.core.InstanceExample;
24 | import moa.core.ObjectRepository;
25 | import moa.options.AbstractOptionHandler;
26 | import moa.streams.InstanceStream;
27 | import moa.tasks.TaskMonitor;
28 |
29 | import java.util.ArrayList;
30 | import java.util.Random;
31 |
32 | /**
33 | * Text generator that simulates sentiment analysis on tweets.
34 | */
35 | public class TextGenerator extends AbstractOptionHandler implements InstanceStream {
36 |
37 | private static final long serialVersionUID = 3028905554604259131L;
38 |
39 | public IntOption numAttsOption = new IntOption("numAtts", 'a',
40 | "The number of attributes to generate.", 1000, 0, Integer.MAX_VALUE);
41 |
42 | public IntOption instanceRandomSeedOption = new IntOption(
43 | "instanceRandomSeed", 'i',
44 | "Seed for random generation of instances.", 1);
45 |
46 | public FloatOption imbalanceRatio = new FloatOption("imbalanceRatio", 'm',
47 | "Percentage of minority class examples", 0.1, 0, 1);
48 |
49 | protected InstancesHeader streamHeader;
50 |
51 | protected Random instanceRandom;
52 |
53 | protected int[] wordTwitterGenerator;
54 | protected double[] freqTwitterGenerator;
55 | protected double[] sumFreqTwitterGenerator;
56 | protected int[] classTwitterGenerator;
57 |
58 | protected int sizeTable;
59 | protected double probPositive = 0.1;
60 | protected double probNegative = 0.1;
61 | protected double zipfExponent = 1.5;
62 | protected double lengthTweet = 15;
63 |
64 | protected int countTweets = 0;
65 |
66 | @Override
67 | public InstancesHeader getHeader() {
68 | return this.streamHeader;
69 | }
70 |
71 | @Override
72 | public long estimatedRemainingInstances() {
73 | return -1;
74 | }
75 |
76 | @Override
77 | public boolean hasMoreInstances() {
78 | return true;
79 | }
80 |
81 | @Override
82 | public InstanceExample nextInstance() {
83 | int[] votes;
84 | double[] attVals;
85 | attVals = new double[this.numAttsOption.getValue() + 1];
86 | Instance inst;
87 |
88 | int label = instanceRandom.nextDouble() < imbalanceRatio.getValue() ? 1 : 0;
89 |
90 | do {
91 | do {
92 | int length = (int) (lengthTweet * (1.0 + this.instanceRandom.nextGaussian()));
93 | if (length < 1) length = 1;
94 | votes = new int[3];
95 | for (int j = 0; j < length; j++) {
96 | double rand = this.instanceRandom.nextDouble();
97 | //binary search
98 | int i = 0;
99 | int min = 0;
100 | int max = sizeTable - 1;
101 | int mid;
102 | do {
103 | mid = (min + max) / 2;
104 | if (rand > this.sumFreqTwitterGenerator[mid]) {
105 | min = mid + 1;
106 | } else {
107 | max = mid - 1;
108 | }
109 | } while ((this.sumFreqTwitterGenerator[mid] != rand) && (min <= max));
110 |
111 | attVals[this.wordTwitterGenerator[mid]] = 1;
112 | votes[this.classTwitterGenerator[mid]]++;
113 |
114 | }
115 | } while (votes[1] == votes[2]);
116 |
117 | inst = new SparseInstance(1.0, attVals);
118 | inst.setDataset(getHeader());
119 | inst.setClassValue((votes[1] > votes[2]) ? 0 : 1);
120 | }while(inst.classValue() != label);
121 |
122 | this.countTweets++;
123 | return new InstanceExample(inst);
124 | }
125 |
126 | @Override
127 | public boolean isRestartable() {
128 | return true;
129 | }
130 |
131 | @Override
132 | public void restart() {
133 |
134 | this.sizeTable = this.numAttsOption.getValue();
135 |
136 | //Prepare table of words to generate tweets
137 | this.wordTwitterGenerator = new int[sizeTable];
138 | this.freqTwitterGenerator = new double[sizeTable];
139 | this.sumFreqTwitterGenerator = new double[sizeTable];
140 | this.classTwitterGenerator = new int[sizeTable];
141 |
142 | this.countTweets = 0;
143 |
144 | double sum = 0;
145 | this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
146 | for (int i = 0; i < this.sizeTable; i++) {
147 | this.wordTwitterGenerator[i] = i + 1;
148 | this.freqTwitterGenerator[i] = 1.0 / Math.pow(i + 1, zipfExponent);
149 | sum += this.freqTwitterGenerator[i];
150 | this.sumFreqTwitterGenerator[i] = sum;
151 | double rand = this.instanceRandom.nextDouble();
152 | this.classTwitterGenerator[i] = (rand < probPositive ? 1 : (rand < probNegative + probPositive ? 2 : 0));
153 | }
154 | for (int i = 0; i < this.sizeTable; i++) {
155 | this.freqTwitterGenerator[i] /= sum;
156 | this.sumFreqTwitterGenerator[i] /= sum;
157 | }
158 |
159 | }
160 |
161 | @Override
162 | protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
163 | generateHeader();
164 | restart();
165 | }
166 |
167 | @Override
168 | public void getDescription(StringBuilder sb, int indent) {
169 |
170 | }
171 | private void generateHeader() {
172 | ArrayList classLabels = new ArrayList();
173 | for (int i = 0; i < 2; i++) {
174 | classLabels.add("class" + (i + 1));
175 | }
176 | ArrayList attributes = new ArrayList();
177 | for (int i = 0; i < this.numAttsOption.getValue(); i++) {
178 | attributes.add(new Attribute("att" + (i + 1), classLabels));
179 | }
180 | attributes.add(new Attribute("class", classLabels));
181 | this.streamHeader = new InstancesHeader(new Instances(
182 | getCLICreationString(InstanceStream.class), attributes, 0));
183 | this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
184 | }
185 |
186 |
187 | public void changePolarity(int numberWords) {
188 | for (int i = 0; i < numberWords; ) {
189 | int randWord = this.instanceRandom.nextInt(this.sizeTable);
190 | int polarity = this.classTwitterGenerator[randWord];
191 | if (polarity == 1) {
192 | this.classTwitterGenerator[i] = 2;
193 | i++;
194 | }
195 | if (polarity == 2) {
196 | this.classTwitterGenerator[i] = 1;
197 | i++;
198 | }
199 | }
200 | }
201 |
202 | public void changeFreqWords(int numberWords) {
203 | for (int i = 0; i < numberWords; i++) {
204 | int randWordTo = this.instanceRandom.nextInt(this.sizeTable);
205 | int randWordFrom = this.instanceRandom.nextInt(this.sizeTable);
206 | this.wordTwitterGenerator[randWordTo] = randWordFrom;
207 | this.wordTwitterGenerator[randWordFrom] = randWordTo;
208 | }
209 | }
210 |
211 |
212 | }
213 |
--------------------------------------------------------------------------------