[jboss-svn-commits] JBL Code SVN: r24137 - in labs/jbossrules/contrib/machinelearning/5.0: drools-core/src/main/java/org/drools/learner/builder and 6 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Fri Nov 28 10:25:28 EST 2008
Author: gizil
Date: 2008-11-28 10:25:28 -0500 (Fri, 28 Nov 2008)
New Revision: 24137
Added:
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Solution.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SolutionSet.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/PrunerStats.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TreeStats.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/Entropy.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/GainRatio.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/Heuristic.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/MinEntropy.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/RandomInfo.java
Removed:
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/BoostedTester.java
Modified:
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/AttributeChooser.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Categorizer.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/test/java/org/drools/learner/StructuredTestFactory.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Poker.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredCarExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredNurseryExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/car_c45_one.stats
Log:
final version (integrated, structured rules)
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -3,6 +3,7 @@
import java.util.ArrayList;
import org.drools.learner.builder.Solution;
+import org.drools.learner.builder.SolutionSet;
import org.drools.learner.builder.test.SingleTreeTester;
//import org.drools.learner.eval.ErrorEstimate;
//import org.drools.learner.eval.TestSample;
@@ -18,33 +19,55 @@
private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(DecisionTreePruner.class, SimpleLogger.DEFAULT_LEVEL);
private static SimpleLogger slog = LoggerFactory.getSysOutLogger(DecisionTreePruner.class, SimpleLogger.DEBUG);
- private PrunerStats best_stats;
+// private PrunerStats best_stats;
private double INIT_ALPHA = 0.5d;
private static final double EPSILON = 0.0d;//0.0000000001;
- private double best_error;
+// private double best_error;
private Solution best_solution;
- ArrayList<Solution> pruned_sol;
+ private ArrayList<Solution> pruned_sol;
+
+
+ private PrunerStats best_stats_everfound;
public DecisionTreePruner() { // seeds the best-ever stats with error estimate 1.0 so any evaluated tree with a lower estimate replaces it
- best_error = 1.0;
+// best_error = 1.0;
- best_stats = new PrunerStats(0.0);//proc.getAlphaEstimate());
+ best_stats_everfound = new PrunerStats(1.0);//proc.getAlphaEstimate());
pruned_sol = new ArrayList<Solution>();
}
public Solution getBestSolution() { // best pruned solution recorded so far; null until a pruning pass stores one
return best_solution;
}
- public void prun_to_estimate(Solution sol) {
+ public void prun_to_estimate(SolutionSet sol_set) {
/*
* The best tree is selected from this series of trees with the classification error not exceeding
* an expected error rate on some test set (cross-validation error),
* which is done at the second stage.
*/
+ int i =0;
+ for (Solution sol: sol_set.getSolutions()) {
+ boolean updated = this.prun_to_estimate(sol);
+ if (updated) {
+ sol_set.setBestSolutionId(i);
+ }
+ }
+
+
+
+ }
+
+ public boolean prun_to_estimate(Solution sol) {
+ /*
+ * The best tree is selected from this series of trees with the classification error not exceeding
+ * an expected error rate on some test set (cross-validation error),
+ * which is done at the second stage.
+ */
+
DecisionTree dt = sol.getTree();
dt.calc_num_node_leaves(dt.getRoot());
@@ -62,24 +85,32 @@
boolean better_found = false;
- int sid = 0, best_st=0;
+ int sid = 0;
int best_id = 0;
- double best_error = 1.0d;
+// double best_error = 1.0d;
System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
ArrayList<PrunerStats> trees = search._getTreeSequenceStats();
for (PrunerStats st: trees ){
- if (st.getErrorEstimation() <= best_error) {
- best_error = st.getErrorEstimation();
+ if (st.getErrorEstimation() < best_stats_everfound.getErrorEstimation() ||
+ (st.getErrorEstimation() == best_stats_everfound.getErrorEstimation() &&
+ st.getNum_terminal_nodes() < best_stats_everfound.getNum_terminal_nodes())) {
+ best_stats_everfound.iteration_id(st.iteration_id());
+ best_stats_everfound.setErrorEstimation(st.getErrorEstimation());
+ best_stats_everfound.setTrainError(st.getTrainError());
+ best_stats_everfound.setNum_terminal_nodes(st.getNum_terminal_nodes());
+ best_stats_everfound.setAlpha(st.getAlpha());
best_id = sid;
- best_st = st.iteration_id();
+// best_st = st.iteration_id();
better_found = true;
+ } else {
+
}
- System.out.println(sid+ "\t " +st.getNum_terminal_nodes()+ "\t "+ st.getErrorEstimation() + "\t "+ st.getTrainError()+ "\t "+ st.getAlpha() );
+ System.out.println(sid+ "\t " +st.getNum_terminal_nodes()+ "\t "+ 100.0d*st.getErrorEstimation() + "\t "+ 100.0d*st.getTrainError()+ "\t "+ st.getAlpha() );
sid++;
}
System.out.println("BEST "+best_id+ "\t " +trees.get(best_id).getNum_terminal_nodes()+ "\t "+ trees.get(best_id).getErrorEstimation() + "\t "+ trees.get(best_id).getTrainError() + "\t "+ trees.get(best_id).getAlpha() );
-// System.exit(0);
+ //System.exit(0);
int N = dt.getTestingDataSize();// procedure.getTestDataSize(dt_i);
double standart_error_estimate = Math.sqrt((trees.get(best_id).getErrorEstimation() * (1- trees.get(best_id).getErrorEstimation())/ (double)N));
@@ -88,8 +119,8 @@
for (int i = search.tree_sequence.size()-1; i>0; i--) {
NodeUpdate nu = search.tree_sequence.get(i);
- System.out.println(nu.iteration_id +">"+ best_st);
- if (nu.iteration_id > best_st ) {
+ System.out.println(nu.iteration_id +">"+ best_stats_everfound.iteration_id());
+ if (nu.iteration_id > best_stats_everfound.iteration_id() ) {
search.add_back(nu.node_update, nu.old_node);
SingleTreeTester t = new SingleTreeTester(search.tree_sol.getTree());
Stats train = t.test(search.tree_sol.getList());
@@ -103,7 +134,9 @@
}
best_solution = search.tree_sol;
+ return true;
}
+ return false;
}
@@ -111,7 +144,7 @@
public void prun_tree(Solution sol) {
double epsilon = 0.0000001 * numExtraMisClassIfPrun(sol.getTree().getRoot());
- TreeSequenceProc search = new TreeSequenceProc(sol, new AnAlphaProc(best_stats.getAlpha(), epsilon));
+ TreeSequenceProc search = new TreeSequenceProc(sol, new AnAlphaProc(best_stats_everfound.getAlpha(), epsilon));
search.iterate_trees(0);
//search.getTreeSequence()// to go back
Deleted: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/BoostedTester.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/BoostedTester.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/BoostedTester.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -1,92 +0,0 @@
-package org.drools.learner.builder;
-
-import java.util.ArrayList;
-
-import org.drools.learner.DecisionTree;
-import org.drools.learner.Domain;
-import org.drools.learner.Instance;
-import org.drools.learner.InstanceList;
-import org.drools.learner.Stats;
-import org.drools.learner.eval.ClassDistribution;
-import org.drools.learner.tools.LoggerFactory;
-import org.drools.learner.tools.SimpleLogger;
-import org.drools.learner.tools.Util;
-
-public class BoostedTester extends Tester{ // removed in this revision; tested a forest by letting every tree vote, weighted by that tree's accuracy
-
- private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(BoostedTester.class, SimpleLogger.DEFAULT_LEVEL);
- private static SimpleLogger slog = LoggerFactory.getSysOutLogger(BoostedTester.class, SimpleLogger.DEFAULT_LEVEL);
-
- private ArrayList<DecisionTree> trees;
- private ArrayList<Double> accuracy; // per-tree vote weight, index-aligned with trees
- private Domain targetDomain;
-
- public BoostedTester(ArrayList<DecisionTree> forest, ArrayList<Double> _accuracy) {
- trees = forest;
- accuracy = _accuracy;
- targetDomain = forest.get(0).getTargetDomain(); // target domain taken from the first tree, so forest must be non-empty
- }
-
- public Stats test(InstanceList data) { // votes on every instance and accumulates each evaluate() result in a Stats object
-
- Stats evaluation = new Stats(data.getSchema().getObjectClass()) ; //represent.getObjClass());
-
- int i = 0;
- for (Instance instance : data.getInstances()) {
- Object forest_decision = this.voteOn(instance);
- Integer result = evaluate(targetDomain, instance, forest_decision);
-
- //flog.debug(Util.ntimes("#\n", 1)+i+ " <START> TEST: instant="+ instance + " = target "+ result);
- if (i%1000 ==0 && slog.stat() != null) // emits a "." stat every 1000 instances
- slog.stat().stat(".");
-
- evaluation.change(result, 1);
- i ++;
- }
- return evaluation;
-
- //printStats(evaluation, executionSignature);
- }
-
- public Object voteOn(Instance i) { // accuracy-weighted vote of all trees for one instance; returns the class picked by evaluateMajority
- ClassDistribution classification = new ClassDistribution(targetDomain);
-
- for (int j = 0; j< trees.size() ; j ++) {
- Object vote = trees.get(j).vote(i);
- if (vote != null) {
- classification.change(vote, accuracy.get(j));
- //classification.change(Util.sum(), accuracy.get(j));
- } else {
- // TODO add an unknown value
- //classification.change(-1, 1);
- if (flog.error() !=null)
- flog.error().log(Util.ntimes("\n", 10)+"Unknown situation at tree: " + j + " for fact "+ i);
- System.exit(0); // NOTE(review): terminates the whole JVM whenever any tree returns no vote
- }
- if (slog.debug() != null)
- slog.debug().log("Vote "+accuracy.get(j)+" for "+vote + "\n");
- }
- classification.evaluateMajority();
- Object winner = classification.get_winner_class();
- if (slog.debug() != null)
- slog.debug().log("Winner = "+winner + "\n");
-
- double ratio = 0.0; // NOTE(review): ratio is computed below but never read; both branches return winner
- if (classification.get_num_ideas() == 1) {
- //100 %
- ratio = 1.0d;
- return winner;
- } else {
- double num_votes = classification.getVoteFor(winner);
- ratio = (num_votes/(double) trees.size());
- // TODO if the ratio is smaller than some number => reject
- }
- return winner;
-
- }
-
- public void printStats(final Stats evaluation, String executionSignature, boolean append) { // only delegates to Tester.printStats
- super.printStats(evaluation, executionSignature, append);
- }
-
-}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -9,9 +9,9 @@
import org.drools.learner.LeafNode;
import org.drools.learner.TreeNode;
import org.drools.learner.eval.AttributeChooser;
-import org.drools.learner.eval.Heuristic;
import org.drools.learner.eval.InformationContainer;
import org.drools.learner.eval.InstDistribution;
+import org.drools.learner.eval.heuristic.Heuristic;
import org.drools.learner.eval.stopping.StoppingCriterion;
import org.drools.learner.tools.FeatureNotSupported;
import org.drools.learner.tools.Util;
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -16,11 +16,13 @@
import org.drools.learner.builder.test.Tester;
import org.drools.learner.builder.DecisionTreeBuilder.TreeAlgo;
import org.drools.learner.eval.CrossValidation;
-import org.drools.learner.eval.Entropy;
import org.drools.learner.eval.ErrorEstimate;
-import org.drools.learner.eval.GainRatio;
-import org.drools.learner.eval.Heuristic;
import org.drools.learner.eval.TestSample;
+import org.drools.learner.eval.heuristic.Entropy;
+import org.drools.learner.eval.heuristic.GainRatio;
+import org.drools.learner.eval.heuristic.Heuristic;
+import org.drools.learner.eval.heuristic.MinEntropy;
+import org.drools.learner.eval.heuristic.RandomInfo;
import org.drools.learner.eval.stopping.EstimatedNodeSize;
import org.drools.learner.eval.stopping.ImpurityDecrease;
import org.drools.learner.eval.stopping.MaximumDepth;
@@ -99,6 +101,16 @@
return createSingleC45(wm, obj_class, new GainRatio(), criteria, null);
}
+ public static DecisionTree createSingleC45E_worst(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createSingleC45(wm, obj_class, new MinEntropy(), criteria, null);
+ }
+
+ public static DecisionTree createSingleC45Random(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createSingleC45(wm, obj_class, new RandomInfo(), criteria, null);
+ }
+
public static DecisionTree createSingleC45E_Stop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(3);
criteria.add(new EstimatedNodeSize(0.5));
@@ -118,14 +130,14 @@
public static DecisionTree createSingleC45E_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported { // single C4.5 tree: Entropy heuristic, EstimatedNodeSize stopping, pruner passed to createSingleC45
DecisionTreePruner pruner = new DecisionTreePruner();
ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
- criteria.add(new EstimatedNodeSize(0.5));
- return createSingleC45(wm, obj_class, new Entropy(), criteria ,pruner);
+ criteria.add(new EstimatedNodeSize(0.05)); // node-size threshold changed from 0.5 to 0.05 in this revision
+ return createSingleC45(wm, obj_class, new Entropy(), criteria, pruner);
}
public static DecisionTree createSingleC45G_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported { // single C4.5 tree: GainRatio heuristic, EstimatedNodeSize stopping, pruner passed to createSingleC45
DecisionTreePruner pruner = new DecisionTreePruner();
ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
- criteria.add(new EstimatedNodeSize(0.5));
+ criteria.add(new EstimatedNodeSize(0.05)); // node-size threshold changed from 0.5 to 0.05 in this revision
return createSingleC45(wm, obj_class, new GainRatio(), criteria ,pruner);
}
@@ -142,8 +154,8 @@
/* create the memory */
Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
- mem.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
- mem.setTestRatio(Util.DEFAULT_TESTING_RATIO);
+// mem.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
+// mem.setTestRatio(Util.DEFAULT_TESTING_RATIO);
mem.processTestSet();
for (StoppingCriterion sc: criteria) {
@@ -164,8 +176,9 @@
Tester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
if (pruner != null) {
- for (Solution sol: product.getSolutions())
- pruner.prun_to_estimate(sol);
+// for (Solution sol: product.getSolutions())
+// pruner.prun_to_estimate(sol);
+ pruner.prun_to_estimate(product);
Solution s2 = pruner.getBestSolution();
Tester t2 = single_builder.getTester(pruner.getBestSolution().getTree());
StatsPrinter.printLatexComment("Pruned TREE", executionSignature, true);
@@ -214,12 +227,14 @@
public static DecisionTree createBagC45E_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported { // bagged C4.5 (createBagC45): Entropy heuristic, EstimatedNodeSize stopping, pruner
DecisionTreePruner pruner = new DecisionTreePruner();
ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ criteria.add(new EstimatedNodeSize(0.05)); // added in this revision; the criteria list was previously left empty
return createBagC45(wm, obj_class, new Entropy(), criteria ,pruner);
}
public static DecisionTree createBagC45G_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported { // bagged C4.5 (createBagC45): GainRatio heuristic, EstimatedNodeSize stopping, pruner
DecisionTreePruner pruner = new DecisionTreePruner();
ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ criteria.add(new EstimatedNodeSize(0.05)); // added in this revision; the criteria list was previously left empty
return createBagC45(wm, obj_class, new GainRatio(), criteria ,pruner);
}
@@ -268,10 +283,16 @@
if (pruner != null) {
- for (Solution sol: product.getSolutions())
- pruner.prun_to_estimate(sol);
-
+// for (Solution sol: product.getSolutions())
+// pruner.prun_to_estimate(sol);
+//
+// Solution s2 = pruner.getBestSolution();
+// product.setBestSolutionId(s2.getTree().getId());
+// for (Solution sol: product.getSolutions())
+// pruner.prun_to_estimate(sol);
+ pruner.prun_to_estimate(product);
Solution s2 = pruner.getBestSolution();
+
Tester t2 = forest.getTester(s2.getTree());
StatsPrinter.printLatexComment("Best Pruned Tree", executionSignature, true);
StatsPrinter.printLatex(t2.test(s2.getList()), t2.test(s2.getTestList()), executionSignature, true);
@@ -323,12 +344,14 @@
public static DecisionTree createBoostC45E_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported { // boosted C4.5 (createBoostC45): Entropy heuristic, EstimatedNodeSize stopping, pruner
DecisionTreePruner pruner = new DecisionTreePruner();
ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ criteria.add(new EstimatedNodeSize(0.05)); // added in this revision; the criteria list was previously left empty
return createBoostC45(wm, obj_class, new Entropy(), criteria ,pruner);
}
public static DecisionTree createBoostC45G_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported { // boosted C4.5 (createBoostC45): GainRatio heuristic, EstimatedNodeSize stopping, pruner
DecisionTreePruner pruner = new DecisionTreePruner();
ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ criteria.add(new EstimatedNodeSize(0.05)); // added in this revision; the criteria list was previously left empty
return createBoostC45(wm, obj_class, new GainRatio(), criteria ,pruner);
}
@@ -374,14 +397,21 @@
if (pruner != null) {
- for (Solution sol: product.getSolutions())
- pruner.prun_to_estimate(sol);
-
+// for (Solution sol: product.getSolutions())
+// pruner.prun_to_estimate(sol);
+ pruner.prun_to_estimate(product);
Solution s2 = pruner.getBestSolution();
+// System.out.println(s2.getTree().getId());
+// product.setBestSolutionId(s2.getTree().getId());
Tester t2 = boosted_forest.getTester(pruner.getBestSolution().getTree());
StatsPrinter.printLatexComment("Best Pruned Tree", executionSignature, true);
StatsPrinter.printLatex(t2.test(s2.getList()), t2.test(s2.getTestList()), executionSignature, true);
+ Tester t2_global = boosted_forest.getTester(s2.getTree());
+ StatsPrinter.printLatexComment("Best Original Tree(Global)", executionSignature, true);
+ StatsPrinter.printLatex(t2_global.test(product.getTrainSet()), t2_global.test(product.getTestSet()), executionSignature, true);
+
+
}
boosted_forest.getBestSolution().getTree().setSignature(executionSignature);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -8,8 +8,8 @@
import org.drools.learner.LeafNode;
import org.drools.learner.TreeNode;
import org.drools.learner.eval.AttributeChooser;
-import org.drools.learner.eval.Heuristic;
import org.drools.learner.eval.InstDistribution;
+import org.drools.learner.eval.heuristic.Heuristic;
import org.drools.learner.tools.Util;
public class ID3Learner extends Learner {
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Solution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Solution.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Solution.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -0,0 +1,73 @@
+package org.drools.learner.builder;
+
+import org.drools.learner.DecisionTree;
+import org.drools.learner.InstanceList;
+import org.drools.learner.Stats;
+
+public class Solution {
+
+ private DecisionTree dt;
+ private InstanceList training_list;
+
+ private InstanceList test_list;
+ private Stats train_stats, test_stats;
+
+ public Solution(DecisionTree _dt, InstanceList list) {
+ dt = _dt;
+ training_list = list;
+ }
+
+
+ public DecisionTree getTree() {
+ return dt;
+ }
+
+ public InstanceList getList() {
+ return training_list;
+ }
+
+ public InstanceList getTestList() {
+ return test_list;
+ }
+
+ public void setTestList(InstanceList test) {
+ test_list = test;
+ }
+
+ public void setTrainStats(Stats train) {
+ train_stats = train;
+ }
+
+ public void setTestStats(Stats test) {
+ test_stats = test;
+ }
+
+ public Stats getTrainStats() {
+ return train_stats;
+ }
+
+ public Stats getTestStats() {
+ return test_stats;
+ }
+
+ public double getTrainError() {
+ System.out.println("Total Train"+ train_stats.getTotal()+ ", size "+ training_list.getSize());
+ return (double)train_stats.getResult(Stats.INCORRECT)/(double)train_stats.getTotal();
+ }
+
+ public double getTestError() {
+ System.out.println("Total Test"+ test_stats.getTotal()+ ", size "+ test_list.getSize());
+ return (double)test_stats.getResult(Stats.INCORRECT)/(double)test_stats.getTotal();
+ }
+
+ public void changeTrainError(int change) {
+ train_stats.change(Stats.INCORRECT, change);
+ train_stats.change(Stats.CORRECT, -1*change);
+ }
+
+ public void setError(int change) {
+ // TODO test_stats.
+ System.out.println("Solution setError() doing nothing");
+ System.exit(0);
+ }
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Solution.java
___________________________________________________________________
Name: svn:eol-style
+ native
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SolutionSet.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SolutionSet.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SolutionSet.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -0,0 +1,106 @@
+package org.drools.learner.builder;
+
+import java.util.ArrayList;
+import java.util.Collection;
+
+import org.drools.learner.DecisionTree;
+import org.drools.learner.InstanceList;
+import org.drools.learner.Memory;
+import org.drools.learner.Stats;
+
+public class SolutionSet {
+
+ private Memory mem;
+ private ArrayList<Solution> sol_trees;
+ private InstanceList global_train_set, global_test_set;
+ private Stats global_train_stats, global_test_stats;
+ private int best_solution_id;
+
+ public SolutionSet(Memory _mem) {
+ mem = _mem;
+ best_solution_id = 0;
+ sol_trees = new ArrayList<Solution>(1);
+ }
+
+ public Collection<String> getTargets() {
+ return mem.getClassInstances().getTargets();
+ }
+
+ public InstanceList getInputSpec() {
+ return mem.getClassInstances();
+ }
+
+// public InstanceList getValidationSet
+
+ public InstanceList getTrainSet() {
+ return mem.getTrainSet();
+ }
+
+ public InstanceList getTestSet() {
+ return mem.getTestSet();
+ }
+
+ public void addSolution(Solution s) {
+ sol_trees.add(s);
+ }
+
+ public void addSolution(DecisionTree tree, InstanceList set, InstanceList test_set) {
+ Solution x = new Solution(tree, set);
+ x.setTestList(test_set);
+ addSolution(x);
+ }
+
+ public ArrayList<Solution> getSolutions() {
+ return sol_trees;
+ }
+
+ public Stats getGlobalTrainStats() {
+ return global_train_stats;
+ }
+
+ public Stats getGlobalTestStats() {
+ return global_test_stats;
+ }
+
+ public void setGlobalTrainStats(Stats train) {
+ global_train_stats = train;
+ }
+
+ public void setGlobalTestStats(Stats test) {
+ global_test_stats = test;
+ }
+
+ public void setBestSolutionId(int best_id) {
+ best_solution_id = best_id;
+ }
+
+ public int getBestSolutionId() {
+ return best_solution_id;
+ }
+ public Solution getBestSolution() {
+ return sol_trees.get(best_solution_id);
+ }
+
+ public int getMinTestId() {
+ double min = 1.0;
+ int id = -1;
+ for (int i=0; i< sol_trees.size(); i++ ) {
+ double test_error = sol_trees.get(i).getTestError();
+ double train_error = sol_trees.get(i).getTrainError();
+ if (test_error < min) {
+ min = test_error;
+ id = i;
+ } else if (test_error == min) {
+ double train_old = sol_trees.get(id).getTrainError();
+ if (train_error < train_old) {
+ min = test_error;
+ id = i;
+ }
+ }
+
+ }
+ return id;
+
+ }
+
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SolutionSet.java
___________________________________________________________________
Name: svn:eol-style
+ native
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/AttributeChooser.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/AttributeChooser.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/AttributeChooser.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -3,6 +3,7 @@
import java.util.List;
import org.drools.learner.Domain;
+import org.drools.learner.eval.heuristic.Heuristic;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Categorizer.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Categorizer.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Categorizer.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -7,7 +7,7 @@
import org.drools.learner.Instance;
import org.drools.learner.InstanceComparator;
import org.drools.learner.QuantitativeDomain;
-import org.drools.learner.eval.Entropy;
+import org.drools.learner.eval.heuristic.Entropy;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/PrunerStats.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/PrunerStats.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/PrunerStats.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -0,0 +1,61 @@
+package org.drools.learner.eval;
+
+public class PrunerStats extends TreeStats{
+
+
+ private int iteration_id;
+
+ private double cost_complexity;
+ private double alpha;
+// private double training_error;
+
+ public PrunerStats() {
+ iteration_id = 0;
+ }
+
+ public PrunerStats(TreeStats ts) {
+ super(ts.getTrainError(), ts.getErrorEstimation());
+ iteration_id = 0;
+ }
+
+ // to set an node update with the worst cross validated error
+ public PrunerStats(double error1) {
+ super(error1);
+ iteration_id = 0;
+ //test_cost = error;
+
+ }
+ public PrunerStats(double error1, double error2) {
+ super(error1, error2);
+ iteration_id = 0;
+ //test_cost = error;
+
+ }
+
+ public void iteration_id(int i) {
+ iteration_id = i;
+ }
+
+ public int iteration_id() {
+ return iteration_id;
+ }
+
+
+ public double getCost_complexity() {
+ return cost_complexity;
+ }
+
+ public void setCost_complexity(double cost_complexity) {
+ this.cost_complexity = cost_complexity;
+ }
+
+ public double getAlpha() {
+ return alpha;
+ }
+
+ public void setAlpha(double alpha) {
+ this.alpha = alpha;
+ }
+
+
+}
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TreeStats.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TreeStats.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TreeStats.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -0,0 +1,50 @@
+package org.drools.learner.eval;
+
+public class TreeStats{
+
+ //private int iteration_id;
+ private int num_terminal_nodes;
+ private double test_error;
+ private double train_error;
+
+
+ public TreeStats() {
+
+ }
+ public TreeStats(double error1, double error2) {
+ train_error = error1;
+ test_error = error2;
+ }
+
+ // to set an node update with the worst cross validated error
+ public TreeStats(double error) {
+// iteration_id = 0;
+ test_error = error;
+ train_error = 1.0d - error;
+ }
+
+ public int getNum_terminal_nodes() {
+ return num_terminal_nodes;
+ }
+
+ public void setNum_terminal_nodes(int num_terminal_nodes) {
+ this.num_terminal_nodes = num_terminal_nodes;
+ }
+
+ public double getErrorEstimation() {
+ return test_error;
+ }
+
+ public void setErrorEstimation(double valid_cost) {
+ this.test_error = valid_cost;
+ }
+
+ public double getTrainError() {
+ return train_error;
+ }
+
+ public void setTrainError(double resubstitution_cost) {
+ this.train_error = resubstitution_cost;
+ }
+
+}
\ No newline at end of file
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TreeStats.java
___________________________________________________________________
Name: svn:eol-style
+ native
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/Entropy.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/Entropy.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/Entropy.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -0,0 +1,215 @@
+package org.drools.learner.eval.heuristic;
+
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.drools.learner.Domain;
+import org.drools.learner.Instance;
+import org.drools.learner.QuantitativeDomain;
+import org.drools.learner.eval.Categorizer;
+import org.drools.learner.eval.ClassDistribution;
+import org.drools.learner.eval.CondClassDistribution;
+import org.drools.learner.eval.InstDistribution;
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
+import org.drools.learner.tools.Util;
+
+public class Entropy implements Heuristic{
+
+ private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(Entropy.class, SimpleLogger.DEFAULT_LEVEL);
+ //public Entropy
+ /*
+ * - chooses the best attribute,
+ * - return the choosen attributes evaluation results (information gain and/or gain ratio)
+ * can process categorical, and quantitative attribute domains
+ *
+ * used by:
+ * c45Alternator, c45Learner, c45Iterator
+ */
+
+ protected static double multiplier = 1.0;
+ protected double data_eval;
+ protected InstDistribution insts_by_target;
+ protected ArrayList<Instance> sorted_instances;
+ protected Domain domain;
+
+ public Entropy() {
+ //
+ }
+
+
+ public Entropy(double m) {
+ multiplier = m;
+ }
+ public void init(InstDistribution _insts_by_target) {
+ insts_by_target = _insts_by_target;
+ data_eval = calc_info(insts_by_target);
+ sorted_instances = null;
+ domain = null;
+ }
+
+
+ public double getEval(Domain attr_domain) {
+ CondClassDistribution insts_by_attr = info_attr(attr_domain);
+ return multiplier *(data_eval - Entropy.calc_info_attr(insts_by_attr));
+ }
+
+ public double getEval_cont(Domain attr_domain) {
+
+ double attribute_eval= 0.0d;
+ QuantitativeDomain trialDomain = QuantitativeDomain.createFromDomain(attr_domain);
+
+ Categorizer visitor = new Categorizer(insts_by_target);
+ visitor.findSplits(trialDomain);
+
+ // trial domain is modified
+ if (trialDomain.getNumIndices() > 1) {
+ CondClassDistribution insts_by_attr = info_contattr(visitor); //.getSortedInstances(), trialDomain);
+ attribute_eval = data_eval - Entropy.calc_info_attr(insts_by_attr);
+ }
+ domain = trialDomain;
+ sorted_instances = visitor.getSortedInstances();
+ return multiplier *attribute_eval;
+ }
+
+ public double getDataEval() {
+ return data_eval;
+ }
+
+ public Domain getDomain() {
+ return domain;
+ }
+
+ public ArrayList<Instance> getSortedInstances() {
+ return sorted_instances;
+ }
+
+ public double getWorstEval() {
+ return -1000.0d;
+ }
+
+ public CondClassDistribution info_attr(Domain attr_domain) {
+
+ Domain target_domain = insts_by_target.getClassDomain();
+
+ //flog.debug("What is the attributeToSplit? " + attr_domain);
+
+ /* initialize the hashtable */
+ CondClassDistribution insts_by_attr = new CondClassDistribution(attr_domain, target_domain);
+ insts_by_attr.setTotal(insts_by_target.getSum());
+
+ //flog.debug("Cond distribution for "+ attr_domain + " \n"+ insts_by_attr);
+
+ for (int category = 0; category<target_domain.getCategoryCount(); category++) {
+ Object targetCategory = target_domain.getCategory(category);
+
+ for (Instance inst: insts_by_target.getSupportersFor(targetCategory)) {
+ Object inst_attr_category = inst.getAttrValue(attr_domain.getFReferenceName());
+
+ Object inst_class = inst.getAttrValue(target_domain.getFReferenceName());
+
+ if (!targetCategory.equals(inst_class)) {
+ if (flog.error() != null)
+ flog.error().log("How the fuck they are not the same ? "+ targetCategory + " " + inst_class);
+ System.exit(0);
+ }
+ insts_by_attr.change(inst_attr_category, targetCategory, inst.getWeight()); //+1
+
+ }
+ }
+
+ return insts_by_attr;
+ }
+
+
+ /* calculates the information of a quantitative domain given the split indexes of instances
+ * a wrapper for the quantitative domain to be able to calculate the stats
+ * */
+ //public static double info_contattr(InstanceList data, Domain targetDomain, QuantitativeDomain splitDomain) {
+ public CondClassDistribution info_contattr(Categorizer visitor) {
+
+ List<Instance> data = visitor.getSortedInstances();
+ QuantitativeDomain splitDomain = visitor.getSplitDomain();
+ Domain targetDomain = insts_by_target.getClassDomain();
+ String targetAttr = targetDomain.getFReferenceName();
+
+ CondClassDistribution instances_by_attr = new CondClassDistribution(splitDomain, targetDomain);
+ instances_by_attr.setTotal(data.size());
+
+ int index = 0;
+ int split_index = 0;
+ Object attr_key = splitDomain.getCategory(split_index);
+ for (Instance i : data) {
+
+ if (index == splitDomain.getSplit(split_index).getIndex()+1 ) {
+ attr_key = splitDomain.getCategory(split_index+1);
+ split_index++;
+ }
+ Object targetKey = i.getAttrValue(targetAttr);
+ instances_by_attr.change(attr_key, targetKey, i.getWeight()); //+1
+
+ index++;
+ }
+
+ return instances_by_attr;
+// double sum = calc_info_attr(instances_by_attr);
+// return sum;
+
+ }
+
+ /*
+ * for both
+ */
+ public static double calc_info_attr( CondClassDistribution instances_by_attr) {
+ //Collection<Object> attributeValues = instances_by_attr.getAttributes();
+ double data_size = instances_by_attr.getTotal();
+ double sum = 0.0;
+ if (data_size>0)
+ for (int attr_idx=0; attr_idx<instances_by_attr.getNumCondClasses(); attr_idx++) {
+ Object attr_category = instances_by_attr.getCondClass(attr_idx);
+ double total_num_attr = instances_by_attr.getTotal_AttrCategory(attr_category);
+
+ if (total_num_attr > 0) {
+ double prob = total_num_attr / data_size;
+ //flog.debug("{("+total_num_attr +"/"+data_size +":"+prob +")* [");
+ double info = calc_info(instances_by_attr.getDistributionOf(attr_category));
+
+ sum += prob * info;
+ //flog.debug("]} ");
+ }
+ }
+ //flog.debug("\n == "+sum);
+ return sum;
+ }
+
+ /* you can calculate this before */
+ /**
+ * it returns the information value of facts entropy that characterizes the
+ * (im)purity of an arbitrary collection of examples
+ *
+ * @param quantity_by_class the distribution of the instances by the class attribute (target)
+ */
+ public static double calc_info(ClassDistribution quantity_by_class) {
+
+ double data_size = quantity_by_class.getSum();
+
+ double prob, sum = 0;
+ Domain target_domain = quantity_by_class.getClassDomain();
+ if (data_size > 0)
+ for (int category = 0; category<target_domain.getCategoryCount(); category++) {
+
+ Object targetCategory = target_domain.getCategory(category);
+ double num_in_class = quantity_by_class.getVoteFor(targetCategory);
+
+ if (num_in_class > 0) {
+ prob = num_in_class / data_size;
+ /* TODO what if it is a sooo small number ???? */
+ //flog.debug("("+num_in_class+ "/"+data_size+":"+prob+")" +"*"+ Util.log2(prob) + " + ");
+ sum -= prob * Util.log2(prob);
+ }
+ }
+ //flog.debug("= " +sum);
+ return sum;
+ }
+}
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/GainRatio.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/GainRatio.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/GainRatio.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -0,0 +1,73 @@
+package org.drools.learner.eval.heuristic;
+
+import org.drools.learner.Domain;
+import org.drools.learner.QuantitativeDomain;
+import org.drools.learner.eval.Categorizer;
+import org.drools.learner.eval.CondClassDistribution;
+import org.drools.learner.tools.Util;
+
+public class GainRatio extends Entropy implements Heuristic{
+
+ public GainRatio() {
+ super();
+ }
+
+
+ public double getEval(Domain attr_domain) {
+ CondClassDistribution insts_by_attr = super.info_attr(attr_domain);
+ double info_gain = super.data_eval - Entropy.calc_info_attr(insts_by_attr);
+
+ double split_info = GainRatio.split_info(insts_by_attr);
+
+ System.err.println("(GainRatio) info_gain = "+ info_gain + "/"+ split_info);
+ return info_gain /split_info;
+ }
+
+ public double getEval_cont(Domain attr_domain) {
+
+ double attribute_eval= 0.0d, split_info = 1.0d;
+ QuantitativeDomain trialDomain = QuantitativeDomain.createFromDomain(attr_domain);
+
+ Categorizer visitor = new Categorizer(insts_by_target);
+ visitor.findSplits(trialDomain);
+
+ // trial domain is modified
+ if (trialDomain.getNumIndices() > 1) {
+ CondClassDistribution insts_by_attr = super.info_contattr(visitor);
+ attribute_eval = super.data_eval - Entropy.calc_info_attr(insts_by_attr);
+
+ split_info = GainRatio.split_info(insts_by_attr);
+ }
+ domain = trialDomain;
+ sorted_instances = visitor.getSortedInstances();
+ return attribute_eval / split_info;
+ }
+
+ private static double split_info( CondClassDistribution instances_by_attr) {
+ //Collection<Object> attributeValues = instances_by_attr.getAttributes();
+ double data_size = instances_by_attr.getTotal();
+ double sum = 1.0;
+ if (data_size>0) {
+ for (int attr_idx = 0; attr_idx < instances_by_attr.getNumCondClasses(); attr_idx++) {
+ Object attr_category = instances_by_attr.getCondClass(attr_idx);
+ double num_in_attr = instances_by_attr.getTotal_AttrCategory(attr_category);
+
+ if (num_in_attr > 0.0) {
+ double prob = num_in_attr / data_size;
+ sum -= prob * Util.log2(prob);
+ }
+ }
+ } else {
+ System.err.println("????? data_size = "+ data_size);
+ System.exit(0);
+ }
+
+ //flog.debug("\n == "+sum);
+ return sum;
+ }
+
+
+
+
+
+}
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/Heuristic.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/Heuristic.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/Heuristic.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -0,0 +1,32 @@
+package org.drools.learner.eval.heuristic;
+
+import java.util.ArrayList;
+
+import org.drools.learner.Domain;
+import org.drools.learner.Instance;
+import org.drools.learner.eval.InstDistribution;
+
+
+public interface Heuristic {
+
+ public void init(InstDistribution _insts_by_target);
+
+ public double getEval(Domain attr_domain);
+ public double getEval_cont(Domain attr_domain);
+
+ public Domain getDomain();
+ public ArrayList<Instance> getSortedInstances();
+
+ public double getWorstEval();
+
+// public abstract double info_attr(InstDistribution insts_by_target, Domain attr_domain);
+// public abstract double info_contattr(List<Instance> data, Domain targetDomain, QuantitativeDomain splitDomain);
+//
+// public abstract double calc_info_attr( CondClassDistribution instances_by_attr);
+// public abstract double calc_info(ClassDistribution quantity_by_class);
+
+
+
+
+
+}
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/MinEntropy.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/MinEntropy.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/MinEntropy.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -0,0 +1,11 @@
+package org.drools.learner.eval.heuristic;
+
+public class MinEntropy extends Entropy implements Heuristic {
+
+ public MinEntropy() {
+ super();
+ super.multiplier = -1.0;
+
+ }
+
+}
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/RandomInfo.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/RandomInfo.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/heuristic/RandomInfo.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -0,0 +1,44 @@
+package org.drools.learner.eval.heuristic;
+
+import java.util.Random;
+
+import org.drools.learner.Domain;
+import org.drools.learner.QuantitativeDomain;
+import org.drools.learner.eval.Categorizer;
+import org.drools.learner.eval.CondClassDistribution;
+
+public class RandomInfo extends Entropy implements Heuristic{
+
+ private static Random info_number = new Random(System.currentTimeMillis());
+ public RandomInfo() {
+ super();
+ }
+
+
+ public double getEval(Domain attr_domain) {
+ CondClassDistribution insts_by_attr = super.info_attr(attr_domain);
+ double info_gain = super.data_eval - Entropy.calc_info_attr(insts_by_attr);
+
+ return info_number.nextDouble(); //info_gain;// /split_info;
+ }
+
+ public double getEval_cont(Domain attr_domain) {
+
+ double attribute_eval= 0.0d, split_info = 1.0d;
+ QuantitativeDomain trialDomain = QuantitativeDomain.createFromDomain(attr_domain);
+
+ Categorizer visitor = new Categorizer(insts_by_target);
+ visitor.findSplits(trialDomain);
+
+ domain = trialDomain;
+ sorted_instances = visitor.getSortedInstances();
+ return info_number.nextDouble();//attribute_eval / split_info;
+ }
+
+
+
+
+
+
+
+}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -35,7 +35,7 @@
private static Random BAGGING = new Random(System.currentTimeMillis());
//public static String log_file = "testing.log";
public static double TRAINING_RATIO = 1.0, TESTING_RATIO = 0.0;;
- public static double DEFAULT_TRAINING_RATIO = 0.84, DEFAULT_TESTING_RATIO= 0.16;
+ public static double DEFAULT_TRAINING_RATIO = 0.80, DEFAULT_TESTING_RATIO= 0.20;
public static String ntimes(String s,int n){
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/test/java/org/drools/learner/StructuredTestFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/test/java/org/drools/learner/StructuredTestFactory.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/test/java/org/drools/learner/StructuredTestFactory.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -7,7 +7,7 @@
import org.drools.learner.builder.Learner.DataType;
import org.drools.learner.builder.test.SingleTreeTester;
import org.drools.learner.builder.DecisionTreeFactory;
-import org.drools.learner.eval.Entropy;
+import org.drools.learner.eval.heuristic.Entropy;
import org.drools.learner.tools.FeatureNotSupported;
public class StructuredTestFactory extends DecisionTreeFactory{
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -58,25 +58,60 @@
case 122:
decision_tree = DecisionTreeFactory.createSingleC45G(session, obj_class);
break;
+ case 123:
+ decision_tree = DecisionTreeFactory.createSingleC45E_worst(session, obj_class);
+ break;
+ case 124:
+ decision_tree = DecisionTreeFactory.createSingleC45Random(session, obj_class);
+ break;
+ case 131:
+ decision_tree = DecisionTreeFactory.createSingleC45E_Stop(session, obj_class);
+ break;
+ case 132:
+ decision_tree = DecisionTreeFactory.createSingleC45G_Stop(session, obj_class);
+ break;
+ case 141:
+ decision_tree = DecisionTreeFactory.createSingleC45E_PrunStop(session, obj_class);
+ break;
+ case 142:
+ decision_tree = DecisionTreeFactory.createSingleC45G_PrunStop(session, obj_class);
+ break;
case 221:
decision_tree = DecisionTreeFactory.createBagC45E(session, obj_class);
break;
case 222:
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
break;
+ case 231:
+ decision_tree = DecisionTreeFactory.createBagC45E_Stop(session, obj_class);
+ break;
+ case 232:
+ decision_tree = DecisionTreeFactory.createBagC45G_Stop(session, obj_class);
+ break;
+ case 241:
+ decision_tree = DecisionTreeFactory.createBagC45E_PrunStop(session, obj_class);
+ break;
+ case 242:
+ decision_tree = DecisionTreeFactory.createBagC45G_PrunStop(session, obj_class);
+ break;
case 321:
decision_tree = DecisionTreeFactory.createBoostC45E(session, obj_class);
- break;
+ break;
case 322:
decision_tree = DecisionTreeFactory.createBoostC45G(session, obj_class);
+ break;
+ case 331:
+ decision_tree = DecisionTreeFactory.createBoostC45E_Stop(session, obj_class);
+ break;
+ case 332:
+ decision_tree = DecisionTreeFactory.createBoostC45G_Stop(session, obj_class);
+ break;
+ case 341:
+ decision_tree = DecisionTreeFactory.createBoostC45E_PrunStop(session, obj_class);
+ break;
+ case 342:
+ decision_tree = DecisionTreeFactory.createBoostC45G_PrunStop(session, obj_class);
break;
-
- case 701:
- decision_tree = DecisionTreeFactory.createBagC45E_Stop(session, obj_class);
- break;
- case 702:
- decision_tree = DecisionTreeFactory.createBoostC45E_Stop(session, obj_class);
- break;
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -25,8 +25,8 @@
final StatefulSession session = ruleBase.newStatefulSession(); // LearningSession
// what are these listeners???
- session.addEventListener( new DebugAgendaEventListener() );
- session.addEventListener( new DebugWorkingMemoryEventListener() );
+// session.addEventListener( new DebugAgendaEventListener() );
+// session.addEventListener( new DebugWorkingMemoryEventListener() );
final WorkingMemoryFileLogger logger = new WorkingMemoryFileLogger( session );
logger.setFileName( "log/nursery" );
@@ -39,7 +39,7 @@
}
// instantiate a learner for a specific object class and pass session to train
- DecisionTree decision_tree; int ALGO = 111;
+ DecisionTree decision_tree; int ALGO = 121;
/*
* Single 1xx, Bag 2xx, Boost 3xx
* ID3 x1x, C45 x2x
@@ -58,13 +58,60 @@
case 122:
decision_tree = DecisionTreeFactory.createSingleC45G(session, obj_class);
break;
+ case 123:
+ decision_tree = DecisionTreeFactory.createSingleC45E_worst(session, obj_class);
+ break;
+ case 124:
+ decision_tree = DecisionTreeFactory.createSingleC45Random(session, obj_class);
+ break;
+ case 131:
+ decision_tree = DecisionTreeFactory.createSingleC45E_Stop(session, obj_class);
+ break;
+ case 132:
+ decision_tree = DecisionTreeFactory.createSingleC45G_Stop(session, obj_class);
+ break;
+ case 141:
+ decision_tree = DecisionTreeFactory.createSingleC45E_PrunStop(session, obj_class);
+ break;
+ case 142:
+ decision_tree = DecisionTreeFactory.createSingleC45G_PrunStop(session, obj_class);
+ break;
case 221:
decision_tree = DecisionTreeFactory.createBagC45E(session, obj_class);
break;
case 222:
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
break;
-
+ case 231:
+ decision_tree = DecisionTreeFactory.createBagC45E_Stop(session, obj_class);
+ break;
+ case 232:
+ decision_tree = DecisionTreeFactory.createBagC45G_Stop(session, obj_class);
+ break;
+ case 241:
+ decision_tree = DecisionTreeFactory.createBagC45E_PrunStop(session, obj_class);
+ break;
+ case 242:
+ decision_tree = DecisionTreeFactory.createBagC45G_PrunStop(session, obj_class);
+ break;
+ case 321:
+ decision_tree = DecisionTreeFactory.createBoostC45E(session, obj_class);
+ break;
+ case 322:
+ decision_tree = DecisionTreeFactory.createBoostC45G(session, obj_class);
+ break;
+ case 331:
+ decision_tree = DecisionTreeFactory.createBoostC45E_Stop(session, obj_class);
+ break;
+ case 332:
+ decision_tree = DecisionTreeFactory.createBoostC45G_Stop(session, obj_class);
+ break;
+ case 341:
+ decision_tree = DecisionTreeFactory.createBoostC45E_PrunStop(session, obj_class);
+ break;
+ case 342:
+ decision_tree = DecisionTreeFactory.createBoostC45G_PrunStop(session, obj_class);
+ break;
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
@@ -83,7 +130,7 @@
ruleBase.addPackage( pkg );
*/
- session.fireAllRules();
+// session.fireAllRules();
ReteStatistics stats = new ReteStatistics(ruleBase);
stats.calculateNumberOfNodes();
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Poker.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Poker.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Poker.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -30,7 +30,7 @@
@FieldAnnotation(readingSeq = 9, discrete=false)
private int c5; // 'Rank of card #5': Numerical (1-13) representing (Ace, 2, 3, ... , Queen, King)
- @FieldAnnotation(readingSeq = 10, ignore = true)//target=true)
+ @FieldAnnotation(readingSeq = 10, ignore = true)//target=true)//
private int poker_hand;
/*
*0: Nothing in hand; not a recognized poker hand
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -48,7 +48,7 @@
// }
// instantiate a learner for a specific object class and pass session to train
- DecisionTree decision_tree; int ALGO = 141;
+ DecisionTree decision_tree; int ALGO = 141;//241;//341
/*
* Single 1xx, Bag 2xx, Boost 3xx
* ID3 x1x, C45 x2x
@@ -67,6 +67,12 @@
case 122:
decision_tree = DecisionTreeFactory.createSingleC45G(session, obj_class);
break;
+ case 123:
+ decision_tree = DecisionTreeFactory.createSingleC45E_worst(session, obj_class);
+ break;
+ case 124:
+ decision_tree = DecisionTreeFactory.createSingleC45Random(session, obj_class);
+ break;
case 131:
decision_tree = DecisionTreeFactory.createSingleC45E_Stop(session, obj_class);
break;
@@ -120,6 +126,7 @@
}
+
final PackageBuilder builder = new PackageBuilder();
//this wil generate the rules, then parse and compile in one step
builder.addPackageFromTree( decision_tree );
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredCarExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredCarExample.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredCarExample.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -79,14 +79,14 @@
* get the compiled package (which is serializable) from the builder
* add the package to a rulebase (deploy the rule package).
*/
- ruleBase.addPackage( builder.getPackage() );
-
- session.fireAllRules();
-
- //session.fireAllRules();
- ReteStatistics stats = new ReteStatistics(ruleBase);
- stats.calculateNumberOfNodes();
- stats.print(Util.DRL_DIRECTORY + decision_tree.getSignature());
+// ruleBase.addPackage( builder.getPackage() );
+//
+// session.fireAllRules();
+//
+// //session.fireAllRules();
+// ReteStatistics stats = new ReteStatistics(ruleBase);
+// stats.calculateNumberOfNodes();
+// stats.print(Util.DRL_DIRECTORY + decision_tree.getSignature());
//logger.writeToDisk();
session.dispose();
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredNurseryExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredNurseryExample.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredNurseryExample.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -88,7 +88,7 @@
// stats.calculateNumberOfNodes();
// stats.print(Util.DRL_DIRECTORY + decision_tree.getSignature());
// //logger.writeToDisk();
-
- session.dispose();
+//
+// session.dispose();
}
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java 2008-11-28 15:25:28 UTC (rev 24137)
@@ -38,7 +38,7 @@
}
// instantiate a learner for a specific object class and pass session to train
- DecisionTree decision_tree; int ALGO = 702;
+ DecisionTree decision_tree; int ALGO = 121;
/*
* Single 1xx, Bag 2xx, Boost 3xx
* ID3 x1x, C45 x2x
@@ -57,16 +57,63 @@
case 122:
decision_tree = DecisionTreeFactory.createSingleC45G(session, obj_class);
break;
+ case 123:
+ decision_tree = DecisionTreeFactory.createSingleC45E_worst(session, obj_class);
+ break;
+ case 124:
+ decision_tree = DecisionTreeFactory.createSingleC45Random(session, obj_class);
+ break;
+ case 131:
+ decision_tree = DecisionTreeFactory.createSingleC45E_Stop(session, obj_class);
+ break;
+ case 132:
+ decision_tree = DecisionTreeFactory.createSingleC45G_Stop(session, obj_class);
+ break;
+ case 141:
+ decision_tree = DecisionTreeFactory.createSingleC45E_PrunStop(session, obj_class);
+ break;
+ case 142:
+ decision_tree = DecisionTreeFactory.createSingleC45G_PrunStop(session, obj_class);
+ break;
case 221:
decision_tree = DecisionTreeFactory.createBagC45E(session, obj_class);
break;
case 222:
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
break;
-
+ case 231:
+ decision_tree = DecisionTreeFactory.createBagC45E_Stop(session, obj_class);
+ break;
+ case 232:
+ decision_tree = DecisionTreeFactory.createBagC45G_Stop(session, obj_class);
+ break;
+ case 241:
+ decision_tree = DecisionTreeFactory.createBagC45E_PrunStop(session, obj_class);
+ break;
+ case 242:
+ decision_tree = DecisionTreeFactory.createBagC45G_PrunStop(session, obj_class);
+ break;
+ case 321:
+ decision_tree = DecisionTreeFactory.createBoostC45E(session, obj_class);
+ break;
+ case 322:
+ decision_tree = DecisionTreeFactory.createBoostC45G(session, obj_class);
+ break;
+ case 331:
+ decision_tree = DecisionTreeFactory.createBoostC45E_Stop(session, obj_class);
+ break;
+ case 332:
+ decision_tree = DecisionTreeFactory.createBoostC45G_Stop(session, obj_class);
+ break;
+ case 341:
+ decision_tree = DecisionTreeFactory.createBoostC45E_PrunStop(session, obj_class);
+ break;
+ case 342:
+ decision_tree = DecisionTreeFactory.createBoostC45G_PrunStop(session, obj_class);
+ break;
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
-
+
}
final PackageBuilder builder = new PackageBuilder();
@@ -85,7 +132,7 @@
ReteStatistics stats = new ReteStatistics(ruleBase);
stats.calculateNumberOfNodes();
- stats.print( decision_tree.getSignature());
+ stats.print(Util.DRL_DIRECTORY + decision_tree.getSignature());
// logger.writeToDisk();
session.dispose();
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/car_c45_one.stats
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/car_c45_one.stats 2008-11-28 15:16:20 UTC (rev 24136)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/car_c45_one.stats 2008-11-28 15:25:28 UTC (rev 24137)
@@ -1,7 +1,10 @@
-TESTING results: incorrect 0
-TESTING results: correct 1728
-TESTING results: unknown 0
-TESTING results: Total Number 1728
+#ORIGINAL TREE
+
+#INCORRECT CORRECT TOTAL
+
+
+ & 0 & 0 & 1728 & 100 & 1728 & 100 & 0 & ? & 0 & ? & 0 & 100\\
+
& OBJECT_TYPE_NODE & ALPHA_NODE & BETA_NODE & TERMINAL_NODE
& 1 & 278 & 0 & 188\\
More information about the jboss-svn-commits
mailing list