[jboss-svn-commits] JBL Code SVN: r21669 - in labs/jbossrules/contrib/machinelearning/5.0: drools-core/src/main/java/org/drools/learner/builder and 5 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Thu Aug 21 21:05:48 EDT 2008
Author: gizil
Date: 2008-08-21 21:05:47 -0400 (Thu, 21 Aug 2008)
New Revision: 21669
Modified:
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/Memory.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/Stats.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/StatsPrinter.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeBuilder.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/test/java/org/drools/learner/StructuredTestFactory.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/MannersLearnerBenchmark.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/RestaurantExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/ShoppingExm.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredCarExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredNurseryExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/golf_c45_one.drl
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/golf_c45_one.stats
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/restaurant_id3_one.drl
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/restaurant_id3_one.stats
Log:
last refactoring
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -5,6 +5,7 @@
import java.util.Collections;
import java.util.HashMap;
+import org.drools.learner.eval.TreeStats;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
@@ -30,11 +31,12 @@
private String execution_signature;
public long FACTS_READ = 0;
- private double validation_error, training_error;
+// private double validation_error, training_error;
+ private TreeStats error_stats;
private int num_nonterminal_nodes;
private int trainingDataSize, testDataSize;
- private InstanceList train, test;
+ //private InstanceList train, test;
public DecisionTree(Schema inst_schema, String _target) {
this.obj_schema = inst_schema; //inst_schema.getObjectClass();
@@ -57,24 +59,20 @@
}
}
Collections.sort(this.attrsToClassify, new DomainComparator()); // compare the domains by the name
-
+ error_stats = new TreeStats(0.0d, 0.0d);
}
public DecisionTree(DecisionTree parentTree, Domain exceptDomain) {
//this.domainSet = new Hashtable<String, Domain<?>>();
this.obj_schema = parentTree.getSchema();
this.target = parentTree.getTargetDomain();
+ this.error_stats = parentTree.error_stats;
+
this.attrsToClassify = new ArrayList<Domain>(parentTree.getAttrDomains().size()-1);
for (Domain attr_domain : parentTree.getAttrDomains()) {
if (attr_domain.isNotJustSelected(exceptDomain))
this.attrsToClassify.add(attr_domain);
- }
-// System.out.print("New tree ");
-// for (Domain d:attrsToClassify)
-// System.out.print("d: "+d);
-// System.out.println("");
- //Collections.sort(this.attrsToClassify, new Comparator<Domain>()); // compare the domains by the name
-
+ }
}
private Schema getSchema() {
@@ -117,21 +115,39 @@
return this.getRoot().voteFor(i);
}
-
- public void setValidationError(double error) {
- validation_error = error;
- }
- public double getValidationError() {
- return validation_error;
+ public TreeStats getStats() {
+ return error_stats;
}
+ public void changeTestError(double error) {
+ error_stats.setErrorEstimation(error_stats.getErrorEstimation() + error);
+ }
- public void setTrainingError(double error) {
- training_error = error;
+ public void changeTrainError(double error) {
+ error_stats.setTrainError(error_stats.getTrainError() + error);
}
- public double getTrainingError() {
- return training_error;
+
+ public double getTestError() {
+ return error_stats.getErrorEstimation();
}
+ public double getTrainError() {
+ return error_stats.getTrainError();
+ }
+
+// public void setValidationError(double error) {
+// validation_error = error;
+// }
+// public double getValidationError() {
+// return validation_error;
+// }
+//
+// public void setTrainingError(double error) {
+// training_error = error;
+// }
+// public double getTrainingError() {
+// return training_error;
+// }
+
public int calc_num_node_leaves(TreeNode my_node) {
if (my_node instanceof LeafNode) {
@@ -248,20 +264,22 @@
return testDataSize;
}
- public void setTrain(InstanceList x) {
- train = x;
- }
-
- public void setTest(InstanceList x) {
- test = x;
- }
+// public void setTrain(InstanceList x) {
+// train = x;
+// }
+//
+// public void setTest(InstanceList x) {
+// test = x;
+// }
+//
+// public InstanceList getTrain() {
+// return train;
+// }
+//
+// public InstanceList getTest() {
+// return test;
+// }
- public InstanceList getTrain() {
- return train;
- }
+
- public InstanceList getTest() {
- return test;
- }
-
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -2,9 +2,11 @@
import java.util.ArrayList;
-import org.drools.learner.builder.SingleTreeTester;
-import org.drools.learner.eval.ErrorEstimate;
-import org.drools.learner.eval.TestSample;
+import org.drools.learner.builder.Solution;
+import org.drools.learner.builder.test.SingleTreeTester;
+//import org.drools.learner.eval.ErrorEstimate;
+//import org.drools.learner.eval.TestSample;
+import org.drools.learner.eval.PrunerStats;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
@@ -16,282 +18,100 @@
private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(DecisionTreePruner.class, SimpleLogger.DEFAULT_LEVEL);
private static SimpleLogger slog = LoggerFactory.getSysOutLogger(DecisionTreePruner.class, SimpleLogger.DEBUG);
- private ErrorEstimate procedure;
+ private PrunerStats best_stats;
- private TreeStats best_stats;
-
- private int num_trees_to_grow;
-
private double INIT_ALPHA = 0.5d;
private static final double EPSILON = 0.0d;//0.0000000001;
- private int best_id = 0;;
- public DecisionTreePruner(ErrorEstimate proc) {
- procedure = proc;
- num_trees_to_grow = procedure.getEstimatorSize();
- //updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
-
- best_stats = new TreeStats(0.0);//proc.getAlphaEstimate());
- }
-
+ private double best_error;
+ private Solution best_solution;
+ ArrayList<Solution> pruned_sol;
+
public DecisionTreePruner() {
- procedure = null;
- num_trees_to_grow = 1;
- //num_trees_to_grow = procedure.getEstimatorSize();
- //updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
- procedure = new TestSample(0.2d);
- best_stats = new TreeStats(0.0);//proc.getAlphaEstimate());
+ best_error = 1.0;
+
+ best_stats = new PrunerStats(0.0);//proc.getAlphaEstimate());
+ pruned_sol = new ArrayList<Solution>();
}
-
- public void prun_to_estimate(ArrayList<DecisionTree> dts) {
- ArrayList<ArrayList<NodeUpdate>> updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
- ArrayList<ArrayList<TreeStats>> sequence_stats = new ArrayList<ArrayList<TreeStats>>(num_trees_to_grow);
- ArrayList<MinAlphaProc> alpha_procs = new ArrayList<MinAlphaProc>(num_trees_to_grow);
+ public Solution getBestSolution() {
+ return best_solution;
+ }
+ public void prun_to_estimate(Solution sol) {
/*
* The best tree is selected from this series of trees with the classification error not exceeding
* an expected error rate on some test set (cross-validation error),
* which is done at the second stage.
*/
-// double value_to_select = procedure.getErrorEstimate();
- for (int dt_i = 0; dt_i<dts.size(); dt_i++) {
-
- DecisionTree dt= dts.get(dt_i);
- procedure.setTrainingDataSize(dt.getTrainingDataSize());
- dt.setID(dt_i);
- dt.calc_num_node_leaves(dt.getRoot());
-
-
+ DecisionTree dt = sol.getTree();
+ dt.calc_num_node_leaves(dt.getRoot());
+
- // dt.getId()
- // dt.calc_num_node_leaves(dt.getRoot()); // this is done in the estimator
+ double epsilon = EPSILON * numExtraMisClassIfPrun(dt.getRoot());
+ MinAlphaProc alpha_proc = new MinAlphaProc(INIT_ALPHA, epsilon);
+ TreeSequenceProc search = new TreeSequenceProc(sol, alpha_proc);//INIT_ALPHA
- double epsilon = EPSILON * numExtraMisClassIfPrun(dt.getRoot());
- MinAlphaProc alpha_proc = new MinAlphaProc(INIT_ALPHA, epsilon);
- TreeSequenceProc search = new TreeSequenceProc(dt, alpha_proc);//INIT_ALPHA
+ search.init_tree(); // alpha_1 = 0.0
+ search.iterate_trees(1);
+
+// updates.add(search.getTreeSequence());
+// equence_stats.add(search.getTreeSequenceStats());
+// alpha_procs.add(alpha_proc);
- search.init_tree(); // alpha_1 = 0.0
- search.iterate_trees(1);
+ boolean better_found = false;
- //updates.add(tree_sequence);
- updates.add(search.getTreeSequence());
- sequence_stats.add(search.getTreeSequenceStats());
- alpha_procs.add(alpha_proc);
-
- // sort the found candidates
- //Collections.sort(updates.get(dt.getId()), arg1)
- int sid = 0, best_st=0;
- best_id = 0;
- double best_error = 1.0;
- System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
- ArrayList<TreeStats> trees = sequence_stats.get(dt.getId());
- for (TreeStats st: trees ){
- if (st.getCostEstimation() <= best_error) {
- best_error = st.getCostEstimation();
- best_id = sid;
- best_st = st.iteration_id;
- }
- System.out.println(sid+ "\t " +st.getNum_terminal_nodes()+ "\t "+ st.getCostEstimation() + "\t "+ st.getResubstitution_cost() + "\t "+ st.getAlpha() );
- sid++;
+ int sid = 0, best_st=0;
+ int best_id = 0;
+ double best_error = 1.0d;
+ System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
+ ArrayList<PrunerStats> trees = search._getTreeSequenceStats();
+ for (PrunerStats st: trees ){
+ if (st.getErrorEstimation() <= best_error) {
+ best_error = st.getErrorEstimation();
+ best_id = sid;
+ best_st = st.iteration_id();
+ better_found = true;
}
- System.out.println("BEST "+best_id+ "\t " +trees.get(best_id).getNum_terminal_nodes()+ "\t "+ trees.get(best_id).getCostEstimation() + "\t "+ trees.get(best_id).getResubstitution_cost() + "\t "+ trees.get(best_id).getAlpha() );
-
- // System.exit(0);
- //int N = procedure.getTestDataSize(dt_i);
- int N = dt.getTestingDataSize();// procedure.getTestDataSize(dt_i);
- double standart_error_estimate = Math.sqrt((trees.get(best_id).getCostEstimation() * (1- trees.get(best_id).getCostEstimation())/ (double)N));
-
- int update_id = search.tree_sequence.size()-1;
- int i = trees.get(trees.size()-1).iteration_id;
- int _sid = trees.size()-1;
- while(i > best_st-1) {
- NodeUpdate nu =search.tree_sequence.get(update_id);
- while (i == nu.iteration_id) {
- search.add_back(nu.node_update, nu.old_node);
- update_id --;
- nu = search.tree_sequence.get(update_id);
- }
- _sid--;
- TreeStats st = trees.get(_sid);
- i = st.iteration_id;
- System.out.println(_sid+ "\t " +st.getNum_terminal_nodes()+ "\t "+ st.getCostEstimation() + "\t "+ st.getResubstitution_cost() + "\t "+ st.getAlpha() );
-
- }
-// System.exit(0);
+ System.out.println(sid+ "\t " +st.getNum_terminal_nodes()+ "\t "+ st.getErrorEstimation() + "\t "+ st.getTrainError()+ "\t "+ st.getAlpha() );
+ sid++;
}
+ System.out.println("BEST "+best_id+ "\t " +trees.get(best_id).getNum_terminal_nodes()+ "\t "+ trees.get(best_id).getErrorEstimation() + "\t "+ trees.get(best_id).getTrainError() + "\t "+ trees.get(best_id).getAlpha() );
-
- }
-
- public DecisionTree prun_to_estimate() {
- ArrayList<ArrayList<NodeUpdate>> updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
- ArrayList<ArrayList<TreeStats>> sequence_stats = new ArrayList<ArrayList<TreeStats>>(num_trees_to_grow);
- ArrayList<MinAlphaProc> alpha_procs = new ArrayList<MinAlphaProc>(num_trees_to_grow);
+// System.exit(0);
- /*
- * The best tree is selected from this series of trees with the classification error not exceeding
- * an expected error rate on some test set (cross-validation error),
- * which is done at the second stage.
- */
- double value_to_select = procedure.getErrorEstimate();
+ int N = dt.getTestingDataSize();// procedure.getTestDataSize(dt_i);
+ double standart_error_estimate = Math.sqrt((trees.get(best_id).getErrorEstimation() * (1- trees.get(best_id).getErrorEstimation())/ (double)N));
- for (int dt_i = 0; dt_i<procedure.getEstimatorSize(); dt_i++) {
- DecisionTree dt= procedure.getEstimator(dt_i);
-
-
- // dt.getId()
- // dt.calc_num_node_leaves(dt.getRoot()); // this is done in the estimator
-
- double epsilon = EPSILON * numExtraMisClassIfPrun(dt.getRoot());
- MinAlphaProc alpha_proc = new MinAlphaProc(INIT_ALPHA, epsilon);
- TreeSequenceProc search = new TreeSequenceProc(dt, alpha_proc);//INIT_ALPHA
-
- search.init_tree(); // alpha_1 = 0.0
- search.iterate_trees(1);
-
- //updates.add(tree_sequence);
- updates.add(search.getTreeSequence());
- sequence_stats.add(search.getTreeSequenceStats());
- alpha_procs.add(alpha_proc);
-
- // sort the found candidates
- //Collections.sort(updates.get(dt.getId()), arg1)
- int sid = 0, best_st = 0;
- best_id = 0;
- double best_error = 1.0;
- System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
- ArrayList<TreeStats> trees = sequence_stats.get(dt.getId());
- for (TreeStats st: trees ){
- if (st.getCostEstimation() <= best_error) {
- best_error = st.getCostEstimation();
- best_id = sid;
- best_st = st.iteration_id;
- }
- System.out.println(sid+ "\t " +st.getNum_terminal_nodes()+ "\t "+ st.getCostEstimation() + "\t "+ st.getResubstitution_cost() + "\t "+ st.getAlpha() );
- sid++;
- }
- System.out.println("BEST "+best_id+ "\t " +trees.get(best_id).getNum_terminal_nodes()+ "\t "+ trees.get(best_id).getCostEstimation() + "\t "+ trees.get(best_id).getResubstitution_cost() + "\t "+ trees.get(best_id).getAlpha() );
+ if (better_found) {
+ for (int i = search.tree_sequence.size()-1; i>0; i--) {
- // System.exit(0);
- int N = procedure.getTestDataSize(dt_i);
- double standart_error_estimate = Math.sqrt((trees.get(best_id).getCostEstimation() * (1- trees.get(best_id).getCostEstimation())/ (double)N));
-
- int update_id = search.tree_sequence.size()-1;
- int i = trees.get(trees.size()-1).iteration_id;
- int _sid = trees.size()-1;
- while(i > best_st) {
- NodeUpdate nu =search.tree_sequence.get(update_id);
- while (i == nu.iteration_id) {
+ NodeUpdate nu = search.tree_sequence.get(i);
+ System.out.println(nu.iteration_id +">"+ best_st);
+ if (nu.iteration_id > best_st ) {
search.add_back(nu.node_update, nu.old_node);
- update_id --;
- nu = search.tree_sequence.get(update_id);
+ SingleTreeTester t = new SingleTreeTester(search.tree_sol.getTree());
+ Stats train = t.test(search.tree_sol.getList());
+ Stats test = t.test(search.tree_sol.getTestList());
+ int num_leaves = search.tree_sol.getTree().getRoot().getNumLeaves();
+
+ System.out.println("Back "+ i+ "\t " +num_leaves+ "\t "+ test.getResult(Stats.INCORRECT)/(double)test.getTotal() + "\t "+ train.getResult(Stats.INCORRECT)/(double)train.getTotal()+ "\t ");
+ } else {
+ break;
}
- _sid--;
- TreeStats st = trees.get(_sid);
- i = st.iteration_id;
- System.out.println(_sid+ "\t " +st.getNum_terminal_nodes()+ "\t "+ st.getCostEstimation() + "\t "+ st.getResubstitution_cost() + "\t "+ st.getAlpha() );
-
}
-// System.exit(0);
- return dt;
+ best_solution = search.tree_sol;
}
- return null;
+
-
-
}
-
-
-// public DecisionTree prun_to_estimate(DecisionTree dt) {
-// ArrayList<ArrayList<NodeUpdate>> updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
-// ArrayList<ArrayList<TreeStats>> sequence_stats = new ArrayList<ArrayList<TreeStats>>(num_trees_to_grow);
-// ArrayList<MinAlphaProc> alpha_procs = new ArrayList<MinAlphaProc>(num_trees_to_grow);
-//
-// /*
-// * The best tree is selected from this series of trees with the classification error not exceeding
-// * an expected error rate on some test set (cross-validation error),
-// * which is done at the second stage.
-// */
-// double value_to_select = procedure.getErrorEstimate();
-//
-// for (int dt_i = 0; dt_i<procedure.getEstimatorSize(); dt_i++) {
-// DecisionTree dt= procedure.getEstimator(dt_i);
-//
-//
-// // dt.getId()
-// // dt.calc_num_node_leaves(dt.getRoot()); // this is done in the estimator
-//
-// double epsilon = EPSILON * numExtraMisClassIfPrun(dt.getRoot());
-// MinAlphaProc alpha_proc = new MinAlphaProc(INIT_ALPHA, epsilon);
-// TreeSequenceProc search = new TreeSequenceProc(dt, alpha_proc);//INIT_ALPHA
-//
-// search.init_tree(); // alpha_1 = 0.0
-// search.iterate_trees(1);
-//
-// //updates.add(tree_sequence);
-// updates.add(search.getTreeSequence());
-// sequence_stats.add(search.getTreeSequenceStats());
-// alpha_procs.add(alpha_proc);
-//
-// // sort the found candidates
-// //Collections.sort(updates.get(dt.getId()), arg1)
-// int sid = 0;
-// best_id = 0;
-// double best_error = 1.0;
-// System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
-// ArrayList<TreeStats> trees = sequence_stats.get(dt.getId());
-// for (TreeStats st: trees ){
-// if (st.getCostEstimation() <= best_error) {
-// best_error = st.getCostEstimation();
-// best_id = sid;
-// }
-// System.out.println(sid+ "\t " +st.getNum_terminal_nodes()+ "\t "+ st.getCostEstimation() + "\t "+ st.getResubstitution_cost() + "\t "+ st.getAlpha() );
-// sid++;
-// }
-// System.out.println("BEST "+best_id+ "\t " +trees.get(best_id).getNum_terminal_nodes()+ "\t "+ trees.get(best_id).getCostEstimation() + "\t "+ trees.get(best_id).getResubstitution_cost() + "\t "+ trees.get(best_id).getAlpha() );
-//
-// // System.exit(0);
-// int N = procedure.getTestDataSize(dt_i);
-// double standart_error_estimate = Math.sqrt((trees.get(best_id).getCostEstimation() * (1- trees.get(best_id).getCostEstimation())/ (double)N));
-//
-// int update_id = search.tree_sequence.size()-1;
-// for (int i = trees.size()-1; i>=best_id; i--){
-// NodeUpdate nu =search.tree_sequence.get(update_id);
-// while (i == nu.iteration_id) {
-// search.add_back(nu.node_update, nu.old_node);
-// update_id ++;
-// nu = search.tree_sequence.get(update_id);
-// }
-// TreeStats st = trees.get(i);
-// System.out.println(sid+ "\t " +st.getNum_terminal_nodes()+ "\t "+ st.getCostEstimation() + "\t "+ st.getResubstitution_cost() + "\t "+ st.getAlpha() );
-// sid++;
-// }
-//
-// return dt;
-// }
-// return null;
-//
-//
-//
-// }
- public void select_tree () {
- /*
- * The best tree is selected from this series of trees with the classification error not exceeding
- * an expected error rate on some test set (cross-validation error),
- * which is done at the second stage.
- */
- double value_to_select = procedure.getErrorEstimate();
-
- }
-
- public void prun_tree(DecisionTree tree) {
- double epsilon = 0.0000001 * numExtraMisClassIfPrun(tree.getRoot());
- TreeSequenceProc search = new TreeSequenceProc(tree, new AnAlphaProc(best_stats.getAlpha(), epsilon));
+ public void prun_tree(Solution sol) {
+ double epsilon = 0.0000001 * numExtraMisClassIfPrun(sol.getTree().getRoot());
+ TreeSequenceProc search = new TreeSequenceProc(sol, new AnAlphaProc(best_stats.getAlpha(), epsilon));
search.iterate_trees(0);
//search.getTreeSequence()// to go back
@@ -307,7 +127,9 @@
// returns the node misclassification cost
private int R(TreeNode t) {
- return (int) t.getNumMatch() - t.getNumLabeled();
+ if (slog.debug() !=null)
+ slog.debug().log(":R:num_misclassified "+ t.getNumMatch() + " "+ t.getNumLabeled());
+ return (int) (t.getNumMatch() - t.getNumLabeled());
}
private int numExtraMisClassIfPrun(TreeNode my_node) {
@@ -327,7 +149,8 @@
slog.debug().log("\n");
if (num_misclassified < kids_misclassified) {
System.out.println("Problem ++++++");
- System.exit(0);
+ //System.exit(0);
+ return 0;
}
return num_misclassified - kids_misclassified;
}
@@ -337,41 +160,56 @@
private static final double MAX_ERROR_RATIO = 0.99;
- private DecisionTree focus_tree;
+ private Solution tree_sol;
//private double the_alpha;
private AlphaSelectionProc alpha_proc;
private ArrayList<NodeUpdate> tree_sequence;
- private ArrayList<TreeStats> tree_sequence_stats;
+// private ArrayList<PrunerStats> tree_sequence_stats;
+ private ArrayList<PrunerStats> _tree_sequence_stats;
- private TreeStats best_tree_stats;
- public TreeSequenceProc(DecisionTree dt, AlphaSelectionProc cond) { //, double init_alpha
- focus_tree = dt;
+ private PrunerStats best_tree_stats;
+ public TreeSequenceProc(Solution sol, AlphaSelectionProc cond) { //, double init_alpha
+ tree_sol = sol;
alpha_proc = cond;
tree_sequence = new ArrayList<NodeUpdate>();
- tree_sequence_stats = new ArrayList<TreeStats>();
+// tree_sequence_stats = new ArrayList<PrunerStats>();
+ _tree_sequence_stats = new ArrayList<PrunerStats>();
- best_tree_stats = new TreeStats(10000000.0d);
+ best_tree_stats = new PrunerStats(10000000.0d);
- NodeUpdate init_tree = new NodeUpdate(dt.getValidationError());
+
+ System.out.println("From solution:"+tree_sol.getTestError()+ " "+ tree_sol.getTrainError() + " " +tree_sol.getTree().getRoot().getNumLeaves());
+ System.out.println("From tree:"+tree_sol.getTree().getTestError()+ " "+ tree_sol.getTree().getTrainError() + " " +tree_sol.getTree().getRoot().getNumLeaves());
+
+
+ NodeUpdate init_tree = new NodeUpdate(tree_sol.getTestError());
tree_sequence.add(init_tree);
- TreeStats init_tree_stats = new TreeStats(dt.getValidationError());
- init_tree_stats.setResubstitution_cost(dt.getTrainingError());
- init_tree_stats.setAlpha(0.0d); // dont know
- init_tree_stats.setCost_complexity(-1); // dont known
-// init_tree_stats.setDecisionTree(dt);
- init_tree_stats.setNum_terminal_nodes(dt.getRoot().getNumLeaves());
- tree_sequence_stats.add(init_tree_stats);
+// PrunerStats init_tree_stats = new PrunerStats(tree_sol.getTree().getStats());
+// init_tree_stats.setNum_terminal_nodes(tree_sol.getTree().getRoot().getNumLeaves());
+// init_tree_stats.setAlpha(0.0d); // dont know
+// init_tree_stats.setCost_complexity(-1); // dont known
+// tree_sequence_stats.add(init_tree_stats);
+ PrunerStats _init_tree_stats = new PrunerStats(tree_sol.getTrainError(), tree_sol.getTestError());
+ _init_tree_stats.setNum_terminal_nodes(tree_sol.getTree().getRoot().getNumLeaves());
+ _init_tree_stats.setAlpha(0.0d); // dont know
+ _init_tree_stats.setCost_complexity(-1); // dont known
+ _tree_sequence_stats.add(_init_tree_stats);
+
}
public DecisionTree getFocusTree() {
- return focus_tree;
+ return tree_sol.getTree();
}
- public ArrayList<TreeStats> getTreeSequenceStats() {
- return tree_sequence_stats;
+// public ArrayList<PrunerStats> getTreeSequenceStats() {
+// return tree_sequence_stats;
+// }
+
+ public ArrayList<PrunerStats> _getTreeSequenceStats() {
+ return _tree_sequence_stats;
}
public ArrayList<NodeUpdate> getTreeSequence() {
@@ -382,19 +220,13 @@
// initialize the tree to be pruned
// T_1 is the smallest subtree of T_max satisfying R(T_1) = R(T_max)
- ArrayList<TreeNode> last_nonterminals = focus_tree.getAnchestor_of_Leaves(focus_tree.getRoot());
+ ArrayList<TreeNode> last_nonterminals = tree_sol.getTree().getAnchestor_of_Leaves(tree_sol.getTree().getRoot());
// R(t) <= sum R(t_children) by Proposition 2.14, where R(t) is the node misclassification cost
// if R(t) = sum R(t_children) then prune off t_children
boolean tree_changed = false; // (k)
for (TreeNode t: last_nonterminals) {
-// int R_children = 0;
-// for (Object child_key: t.getChildrenKeys()) {
-// TreeNode child = t.getChild(child_key);
-// R_children += child.getMissClassified();
-// }
-// if (t.getMissClassified() == R_children) {
if (numExtraMisClassIfPrun(t) == 0) {
// prune off the candidate node
tree_changed = true;
@@ -403,25 +235,26 @@
}
if (tree_changed) {
- TreeStats stats = new TreeStats();
+ PrunerStats stats = new PrunerStats();
stats.iteration_id(0);
- update_tree_stats(stats, 0.0d, 0); // error_estimation = stats.getCostEstimation() for the set (cross_validation or test error)
+ //update_tree_stats(stats, 0.0d, 0); // error_estimation = stats.getCostEstimation() for the set (cross_validation or test error)
+ _update_tree_stats(stats, 0.0d, 0);
// tree_sequence_stats.add(stats);
}
}
private void iterate_trees(int i) {
if (slog.debug() !=null)
- slog.debug().log(focus_tree.toString() +"\n");
+ slog.debug().log(tree_sol.getTree().toString() +"\n");
// if number of non-terminal nodes in the tree is more than 1
// = if there exists at least one non-terminal node different than root
- if (focus_tree.getNumNonTerminalNodes() < 1) {
+ if (tree_sol.getTree().getNumNonTerminalNodes() < 1) {
if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:TERMINATE-There is no non-terminal nodes? " + focus_tree.getNumNonTerminalNodes() +"\n");
+ slog.debug().log(":sequence_trees:TERMINATE-There is no non-terminal nodes? " + tree_sol.getTree().getNumNonTerminalNodes() +"\n");
return;
- } else if (focus_tree.getNumNonTerminalNodes() == 1 && focus_tree.getRoot().getNumLeaves()<=1) {
+ } else if (tree_sol.getTree().getNumNonTerminalNodes() == 1 && tree_sol.getTree().getRoot().getNumLeaves()<=1) {
if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:TERMINATE-There is only one node left which is root node " + focus_tree.getNumNonTerminalNodes()+ " and it has only one leaf (pruned)" +focus_tree.getRoot().getNumLeaves()+"\n");
+ slog.debug().log(":sequence_trees:TERMINATE-There is only one node left which is root node " + tree_sol.getTree().getNumNonTerminalNodes()+ " and it has only one leaf (pruned)" +tree_sol.getTree().getRoot().getNumLeaves()+"\n");
return;
}
// for each non-leaf subtree
@@ -430,10 +263,10 @@
//TreeSequenceProc search = new TreeSequenceProc(the_alpha, alpha_proc);//100000.0d, new MinAlphaProc());
- find_candidate_nodes(focus_tree.getRoot(), candidate_nodes);
+ find_candidate_nodes(tree_sol.getTree().getRoot(), candidate_nodes);
//double min_alpha = search.getTheAlpha();
double min_alpha = getTheAlpha();
- System.out.println("!!!!!!!!!!! dt:"+focus_tree.getId()+" ite "+i+" alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
+ System.out.println("!!!!!!!!!!! dt:"+tree_sol.getTree().getId()+" ite "+i+" alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
if (candidate_nodes.size() >0) {
// The one or more subtrees with that value of alpha will be replaced by leaves
@@ -449,17 +282,19 @@
change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
}
- TreeStats stats = new TreeStats();
+// PrunerStats stats = new PrunerStats();
+// stats.iteration_id(i);
+// update_tree_stats(stats, min_alpha, change_in_training_misclass); // error_estimation = stats.getCostEstimation() for the set (cross_validation or test error)
+ PrunerStats stats = new PrunerStats();
stats.iteration_id(i);
- update_tree_stats(stats, min_alpha, change_in_training_misclass); // error_estimation = stats.getCostEstimation() for the set (cross_validation or test error)
-
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:error "+ stats.getCostEstimation() +"<?"+ procedure.getErrorEstimate() * 1.6 +"\n");
+ _update_tree_stats(stats, min_alpha, change_in_training_misclass);
+// if (slog.debug() !=null)
+// slog.debug().log(":sequence_trees:error "+ stats.getCostEstimation() +"<?"+ procedure.getErrorEstimate() * 1.6 +"\n");
- if (stats.getCostEstimation() < MAX_ERROR_RATIO) { //procedure.getValidationErrorEstimate() * 1.6) {
+ if (stats.getErrorEstimation() < MAX_ERROR_RATIO) { //procedure.getValidationErrorEstimate() * 1.6) {
// if the error of the tree is not that bad
- if (stats.getCostEstimation() < best_tree_stats.getCostEstimation()) {
+ if (stats.getErrorEstimation() < best_tree_stats.getErrorEstimation()) {
best_tree_stats = stats;
if (slog.debug() !=null)
slog.debug().log(":sequence_trees:best node updated \n");
@@ -482,8 +317,8 @@
private void prune_off(TreeNode candidate_node, int i) {
-
- LeafNode best_clone = new LeafNode(focus_tree.getTargetDomain(), candidate_node.getLabel());
+ System.out.println(tree_sol.getTree().getTargetDomain() + " " + candidate_node.getLabel());
+ LeafNode best_clone = new LeafNode(tree_sol.getTree().getTargetDomain(), candidate_node.getLabel());
best_clone.setRank( candidate_node.getRank());
best_clone.setNumMatch(candidate_node.getNumMatch()); //num of matching instances to the leaf node
best_clone.setNumClassification(candidate_node.getNumLabeled()); //num of (correctly) classified instances at the leaf node
@@ -492,7 +327,7 @@
//TODO
update.iteration_id(i);
- update.setDecisionTree(focus_tree);
+ update.setDecisionTree(tree_sol.getTree());
//change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
int num_leaves = candidate_node.getNumLeaves();
@@ -501,13 +336,14 @@
for(Object key: father_node.getChildrenKeys()) {
if (father_node.getChild(key).equals(candidate_node)) {
father_node.putNode(key, best_clone);
+ best_clone.setFather(father_node);
break;
}
}
updateLeaves(father_node, -num_leaves+1);
} else {
// this node does not have any father node it is the root node of the tree
- focus_tree.setRoot(best_clone);
+ tree_sol.getTree().setRoot(best_clone);
}
//updates.get(dt_0.getId()).add(update);
@@ -517,29 +353,31 @@
public void add_back(TreeNode leaf_node, TreeNode original_node) {
TreeNode father = leaf_node.getFather();
if (father == null) {
- focus_tree.setRoot(original_node);
+ tree_sol.getTree().setRoot(original_node);
return;
}
- else
- for(Object key: father.getChildrenKeys()) {
- if (father.getChild(key).equals(leaf_node)) {
- father.putNode(key, original_node);
- break;
+ else {
+ int num_leaves = original_node.getNumLeaves();
+ for(Object key: father.getChildrenKeys()) {
+ if (father.getChild(key).equals(leaf_node)) {
+ father.putNode(key, original_node);
+ original_node.setFather(father);
+ updateLeaves(father, num_leaves-1);
+ break;
+ }
}
}
}
- private void update_tree_stats(TreeStats stats, double computed_alpha, int change_in_training_error) {
- //TODO put back
- // ArrayList<InstanceList> sets = procedure.getSets(focus_tree.getId());
-// InstanceList learning_set = sets.get(0);
-// InstanceList validation_set = sets.get(1);
+
+ private void _update_tree_stats(PrunerStats stats, double computed_alpha, int change_in_training_error) {
- InstanceList learning_set = focus_tree.getTrain();
- InstanceList validation_set = focus_tree.getTest();
+ InstanceList learning_set = tree_sol.getList();
+ InstanceList validation_set = tree_sol.getTestList();
int num_error = 0;
- SingleTreeTester t= new SingleTreeTester(focus_tree);
+ SingleTreeTester t= new SingleTreeTester(tree_sol.getTree());
+ Stats test_stats = t.test(tree_sol.getTestList());
//System.out.println(validation_set.getSize());
for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
@@ -549,34 +387,25 @@
}
}
double percent_error = Util.division(num_error, validation_set.getSize());
-// int _num_error = 0;
-// for (int index_i = 0; index_i < learning_set.getSize(); index_i++) {
-// Integer result = t.test(learning_set.getInstance(index_i));
-// if (result == Stats.INCORRECT) {
-// _num_error ++;
-// }
-// }
-// double _percent_error = Util.division(_num_error, learning_set.getSize());
-
- int new_num_leaves = focus_tree.getRoot().getNumLeaves();
+ System.out.println("From test: "+ test_stats.getResult(Stats.INCORRECT)/(double)test_stats.getTotal() + " x "+ percent_error);
+ int new_num_leaves = tree_sol.getTree().getRoot().getNumLeaves();
- double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_error, procedure.getTrainingDataSize(focus_tree.getId()));
- focus_tree.setTrainingError(new_resubstitution_cost);
+ //double new_resubstitution_cost = tree_sol.getTree().getTrainError() + Util.division(change_in_training_error, learning_set.getSize());
+ tree_sol.changeTrainError(change_in_training_error);//setTrainingError(new_resubstitution_cost);
+ //tree_sol.setError();
+ double cost_complexity = tree_sol.getTrainError() + computed_alpha * (new_num_leaves);
- double cost_complexity = new_resubstitution_cost + computed_alpha * (new_num_leaves);
-
if (slog.debug() !=null)
slog.debug().log(":sequence_trees:cost_complexity of selected tree "+ cost_complexity +"\n");
stats.setAlpha(computed_alpha);
- stats.setCostEstimation(percent_error);
- //stats.setTrainingCost(_percent_error);
- stats.setResubstitution_cost(new_resubstitution_cost);
+ stats.setErrorEstimation(percent_error);
+ stats.setTrainError(tree_sol.getTrainError());
// Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
stats.setCost_complexity(cost_complexity);
stats.setNum_terminal_nodes(new_num_leaves);
- tree_sequence_stats.add(stats);
+ _tree_sequence_stats.add(stats);
}
// memory optimized
@@ -597,10 +426,10 @@
slog.debug().log(":search_alphas:k == 0\n" );
}
- double num_training_data = (double)procedure.getTrainingDataSize(focus_tree.getId());
+ double num_training_data = (double)tree_sol.getList().getSize();
double alpha = ( (double)k / num_training_data) /((double)(num_leaves-1));
if (slog.debug() !=null)
- slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+alpha_proc.getAlpha()+ " k "+k+" num_leaves "+num_leaves+" all "+ procedure.getTrainingDataSize(focus_tree.getId()) + "\n");
+ slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+alpha_proc.getAlpha()+ " k "+k+" num_leaves "+num_leaves+" all "+ tree_sol.getList().getSize() + "\n");
//the_alpha = alpha_proc.check_node(alpha, the_alpha, my_node, nodes);
alpha_proc.check_node(alpha, my_node, nodes);
@@ -618,147 +447,7 @@
return alpha_proc.getAlpha();
}
-
-
-
- private void iterate_trees_(int i) {
- if (slog.debug() !=null)
- slog.debug().log(focus_tree.toString() +"\n");
- // if number of non-terminal nodes in the tree is more than 1
- // = if there exists at least one non-terminal node different than root
- if (focus_tree.getNumNonTerminalNodes() < 1) {
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:TERMINATE-There is no non-terminal nodes? " + focus_tree.getNumNonTerminalNodes() +"\n");
- return;
- } else if (focus_tree.getNumNonTerminalNodes() == 1 && focus_tree.getRoot().getNumLeaves()<=1) {
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:TERMINATE-There is only one node left which is root node " + focus_tree.getNumNonTerminalNodes()+ " and it has only one leaf (pruned)" +focus_tree.getRoot().getNumLeaves()+"\n");
- return;
- }
- // for each non-leaf subtree
- ArrayList<TreeNode> candidate_nodes = new ArrayList<TreeNode>();
- // to find all candidates with min_alpha value
- //TreeSequenceProc search = new TreeSequenceProc(the_alpha, alpha_proc);//100000.0d, new MinAlphaProc());
-
-
- find_candidate_nodes(focus_tree.getRoot(), candidate_nodes);
- //double min_alpha = search.getTheAlpha();
- double min_alpha = getTheAlpha();
- System.out.println("!!!!!!!!!!! dt:"+focus_tree.getId()+" ite "+i+" alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
-
- if (candidate_nodes.size() >0) {
- // The one or more subtrees with that value of will be replaced by leaves
- // instead of getting the first node, have to process all nodes and prune all
- // write a method to prune all
- //TreeNode best_node = candidate_nodes.get(0);
- TreeStats stats = new TreeStats();
- stats.iteration_id(i);
-
- int change_in_training_misclass = 0; // (k)
- for (TreeNode candidate_node:candidate_nodes) {
- // prune off the candidate node
- //prune_off(candidate_node, i);
- //change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
-/* */
- LeafNode best_clone = new LeafNode(focus_tree.getTargetDomain(), candidate_node.getLabel());
- best_clone.setRank( candidate_node.getRank());
- best_clone.setNumMatch(candidate_node.getNumMatch()); //num of matching instances to the leaf node
- best_clone.setNumClassification(candidate_node.getNumLabeled()); //num of (correctly) classified instances at the leaf node
-
- NodeUpdate update = new NodeUpdate(candidate_node, best_clone);
- update.iteration_id(i);
-
- update.setDecisionTree(focus_tree);
-
- int num_leaves = candidate_node.getNumLeaves();
- TreeNode father_node = candidate_node.getFather();
- if (father_node != null) {
- for(Object key: father_node.getChildrenKeys()) {
- if (father_node.getChild(key).equals(candidate_node)) {
- father_node.putNode(key, best_clone);
- break;
- }
- }
- updateLeaves(father_node, -num_leaves+1);
- } else {
- // this node does not have any father node it is the root node of the tree
- focus_tree.setRoot(best_clone);
- }
-
- //updates.get(dt_0.getId()).add(update);
- tree_sequence.add(update);
-/**/
- }
-/**/
- ArrayList<InstanceList> sets = procedure.getSets(focus_tree.getId());
- InstanceList learning_set = sets.get(0);
- InstanceList validation_set = sets.get(1);
-
- int num_error = 0;
- SingleTreeTester t= new SingleTreeTester(focus_tree);
-
- System.out.println(validation_set.getSize());
- for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
- Integer result = t.test(validation_set.getInstance(index_i));
- if (result == Stats.INCORRECT) {
- num_error ++;
- }
- }
- double percent_error = Util.division(num_error, validation_set.getSize());
-// int _num_error = 0;
-// for (int index_i = 0; index_i < learning_set.getSize(); index_i++) {
-// Integer result = t.test(learning_set.getInstance(index_i));
-// if (result == Stats.INCORRECT) {
-// _num_error ++;
-// }
-// }
-// double _percent_error = Util.division(_num_error, learning_set.getSize());
-
- int new_num_leaves = focus_tree.getRoot().getNumLeaves();
-
- double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_misclass, procedure.getTrainingDataSize(focus_tree.getId()));
- double cost_complexity = new_resubstitution_cost + min_alpha * (new_num_leaves);
-
-
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:cost_complexity of selected tree "+ cost_complexity +"\n");
-
-
- stats.setAlpha(min_alpha);
- stats.setCostEstimation(percent_error);
-// stats.setTrainingCost(_percent_error);
- stats.setResubstitution_cost(new_resubstitution_cost);
- // Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
- stats.setCost_complexity(cost_complexity);
- stats.setNum_terminal_nodes(new_num_leaves);
-
-/**/
- tree_sequence_stats.add(stats);
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:error "+ percent_error +"<?"+ procedure.getErrorEstimate() * 1.6 +"\n");
-
- if (percent_error < MAX_ERROR_RATIO) { //procedure.getValidationErrorEstimate() * 1.6) {
- // if the error of the tree is not that bad
-
- if (percent_error < best_tree_stats.getCostEstimation()) {
- best_tree_stats = stats;
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:best node updated \n");
-
- }
- // TODO update alpha_proc by increasing the min_alpha the_alpha += 10;
- alpha_proc.init_proc(alpha_proc.getAlpha() + INIT_ALPHA);
- iterate_trees_(i+1);
- } else {
- //TODO update.setStopTree();
- return;
- }
- } else {
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:no candidate node is found ???? \n");
- }
-
- }
+
}
@@ -906,78 +595,7 @@
}
- public class TreeStats{
-
- private int iteration_id;
- private int num_terminal_nodes;
- private double test_cost;
- private double resubstitution_cost;
- private double cost_complexity;
- private double alpha;
-// private double training_error;
-
- public TreeStats() {
- iteration_id = 0;
- }
-
- // to set an node update with the worst cross validated error
- public TreeStats(double error) {
- iteration_id = 0;
- test_cost = error;
- }
-
- public void iteration_id(int i) {
- iteration_id = i;
- }
- public int getNum_terminal_nodes() {
- return num_terminal_nodes;
- }
-
- public void setNum_terminal_nodes(int num_terminal_nodes) {
- this.num_terminal_nodes = num_terminal_nodes;
- }
-
- public double getCostEstimation() {
- return test_cost;
- }
-
- public void setCostEstimation(double valid_cost) {
- this.test_cost = valid_cost;
- }
-
- public double getResubstitution_cost() {
- return resubstitution_cost;
- }
-
- public void setResubstitution_cost(double resubstitution_cost) {
- this.resubstitution_cost = resubstitution_cost;
- }
-
- public double getCost_complexity() {
- return cost_complexity;
- }
-
- public void setCost_complexity(double cost_complexity) {
- this.cost_complexity = cost_complexity;
- }
-
- public double getAlpha() {
- return alpha;
- }
-
- public void setAlpha(double alpha) {
- this.alpha = alpha;
- }
-
-// public void setTrainingCost(double _percent_error) {
-// training_error = _percent_error;
-// }
-//
-// public double getTrainingCost() {
-// return training_error;
-// }
-
- }
+
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/Memory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/Memory.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/Memory.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -7,6 +7,7 @@
import org.drools.learner.builder.Learner.DataType;
import org.drools.learner.builder.Learner.DomainAlgo;
import org.drools.learner.tools.FeatureNotSupported;
+import org.drools.learner.tools.Util;
public class Memory {
@@ -31,8 +32,6 @@
// create a instance list that can hold objects from our schema
mem.instances.put(clazz, new InstanceList(inst_schema, _session));
-
-// mem.test_instances.put(clazz, new InstanceList(inst_schema, _session));
/*
* do they create an ObjectTypeNode for each new inserted object type?
@@ -46,50 +45,12 @@
//if (clazz.isAssignableFrom(obj.getClass()))
mem.instances.get(clazz).addStructuredInstance(obj);
}
- //dt.FACTS_READ += facts.size();
return mem;
}
- public static Memory createTestFromWorkingMemory(Memory old_memory, WorkingMemory _session, Class<?> clazz, DomainAlgo domain, DataType data) throws FeatureNotSupported {
- // if mem == null
- Memory mem = new Memory();
- mem.session = _session;
- mem.setClassToClassify(clazz);
-
- // create schema from clazz
- Schema inst_schema = old_memory.getClassInstances().getSchema();
-// try {
-// inst_schema = Schema.createSchemaStructure(clazz, domain, data);
-// } catch (Exception e) {
-// // TODO Auto-generated catch block
-// e.printStackTrace();
-// System.exit(0);
-// }
-
- // create a instance list that can hold objects from our schema
- mem.instances.put(clazz, new InstanceList(inst_schema, _session));
-
- /*
- * do they create an ObjectTypeNode for each new inserted object type?
- * even if there is no rule exists.
- * No probably they do not
- */
- Iterator<Object> it_object = _session.iterateObjects(); // how can i get the object type nodes
- while (it_object.hasNext()) {
- Object obj = it_object.next();
- // validating in the the factory during instantiation
- //if (clazz.isAssignableFrom(obj.getClass()))
- mem.instances.get(clazz).addStructuredInstance(obj);
- }
- //dt.FACTS_READ += facts.size();
-
- return mem;
- }
-
-
// Drools memory
private WorkingMemory session;
//// class specification
@@ -99,18 +60,47 @@
// instance list used to train
private HashMap<Class<?>,InstanceList> instances;
+ private InstanceList train_instances, test_instances;
+
// // instance list used to test
// private HashMap<Class<?>,InstanceList> test_instances;
+ private double trainRatio = Util.TRAINING_RATIO;
+ private double testRatio = Util.TESTING_RATIO;
private Memory() {
this.instances = new HashMap<Class<?>, InstanceList>();
}
+ public void setTrainRatio(double ratio) {
+ trainRatio = ratio;
+ }
+
+ public void setTestRatio(double ratio) {
+ testRatio = ratio;
+ }
+
+ public InstanceList getTrainSet() {
+ return train_instances;
+ }
+ public InstanceList getTestSet() {
+ return test_instances;
+ }
+
+ public void processTestSet() {
+ int split_idx = (int)(trainRatio * instances.get(this.clazzToClassify).getSize());
+ //int split_idx2 = split_idx + (int)(testRatio * instances.get(this.clazzToClassify).getSize());
+ int split_idx2 = instances.get(this.clazzToClassify).getSize();
+
+ train_instances = instances.get(this.clazzToClassify).subList(0, split_idx);
+ test_instances = instances.get(this.clazzToClassify).subList(split_idx, split_idx2);//class_instances.getSize());
+ return;
+ }
+
private void setClassToClassify(Class<?> clazz) {
this.clazzToClassify = clazz;
}
-
+ // target class instances
public InstanceList getClassInstances() {
return instances.get(this.clazzToClassify);
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/Stats.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/Stats.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/Stats.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -77,14 +77,6 @@
}
}
-// public String print4Latex() {
-//
-// double in = 100.0d *(double)getResult(Stats.INCORRECT)/(double) getTotal();
-// double co = 100.0d *(double)getResult(Stats.CORRECT)/(double) getTotal();
-// double un = 100.0d *(double)getResult(Stats.UNKNOWN)/(double) getTotal();
-// return "Builder" +"\t&\t"+ getResult(Stats.INCORRECT)+ "\t&\t"+ precision.format(in) + "\t&\t"+
-// getResult(Stats.CORRECT)+"\t&\t"+ precision.format(co)+ "\t&\t"+ getResult(Stats.UNKNOWN)+"\t&\t"+ precision.format(un)+ "\\\\";
-// }
public String print4Latex() {
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/StatsPrinter.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/StatsPrinter.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/StatsPrinter.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -5,6 +5,8 @@
import java.io.IOException;
import java.io.Writer;
+import org.drools.learner.tools.Util;
+
public class StatsPrinter {
@@ -30,7 +32,33 @@
wr.close();
}
+ public static void print(Stats train, Stats test) {
+ StringBuffer sb = new StringBuffer();
+ sb.append("#"+ Stats.getErrors());
+ sb.append( "\n");
+ sb.append(train.print4Latex() +test.print4Latex()+"\\\\"+ "\n");
+ sb.append( "\n");
+
+ System.out.println(sb.toString());
+ }
+ public static void printLatexComment(String comment, String executionSignature, boolean append) {
+ StringBuffer sb = new StringBuffer();
+ sb.append("#"+ comment+"\n");
+ StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, append);
+ }
+ public static void printLatexLine(String executionSignature, boolean append) {
+ StringBuffer sb = new StringBuffer();
+ sb.append( "\n");
+ StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, append);
+ }
+
+ public static void printLatex(Stats train, Stats test, String executionSignature, boolean append) {
+ StringBuffer sb = new StringBuffer();
+ sb.append(train.print4Latex() +test.print4Latex()+"\\\\"+ "\n");
+ StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, append);
+ }
+
public static void print2file(StringBuffer sb, String fileSignature, boolean append) {
//String dataFileName = "src/main/rules/"+_packageNames+"/"+ fileName;
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -7,6 +7,9 @@
import org.drools.learner.Memory;
import org.drools.learner.Stats;
import org.drools.learner.StatsPrinter;
+import org.drools.learner.builder.test.BoostedTester;
+import org.drools.learner.builder.test.SingleTreeTester;
+import org.drools.learner.builder.test.Tester;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
@@ -14,14 +17,13 @@
/*
*
*/
-public class AdaBoostBuilder implements DecisionTreeBuilder{
+public class AdaBoostBuilder extends DecisionTreeBuilder{
private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(AdaBoostBuilder.class, SimpleLogger.DEFAULT_LEVEL);
private static SimpleLogger slog = LoggerFactory.getSysOutLogger(AdaBoostBuilder.class, SimpleLogger.DEFAULT_LEVEL);
private TreeAlgo algorithm = TreeAlgo.BOOST; // default bagging, TODO boosting
- private double trainRatio = Util.TRAINING_RATIO, testRatio = Util.TESTING_RATIO;
private static int FOREST_SIZE = Util.NUM_TREES;
private static final double TREE_SIZE_RATIO = 1.0d;
@@ -31,37 +33,27 @@
private ArrayList<Double> classifier_accuracy;
private DecisionTreeMerger merger;
+
- private DecisionTree best;
+// private BoostedTester tester;
+//
+// private ArrayList<Stats> train_evaluation, test_evaluation;
- private BoostedTester tester;
-
- private ArrayList<Stats> train_evaluation, test_evaluation;
-
public AdaBoostBuilder() {
//this.trainer = _trainer;
merger = new DecisionTreeMerger();
- train_evaluation = new ArrayList<Stats>(FOREST_SIZE);
- test_evaluation = new ArrayList<Stats>(FOREST_SIZE);
+// train_evaluation = new ArrayList<Stats>(FOREST_SIZE);
+// test_evaluation = new ArrayList<Stats>(FOREST_SIZE);
}
- public void setTrainRatio(double ratio) {
- trainRatio = ratio;
- }
- public void setTestRatio(double ratio) {
- testRatio = ratio;
- }
- public void build(Memory mem, Learner _trainer) {
-
- final InstanceList class_instances = mem.getClassInstances();
- _trainer.setInputData(class_instances);
-
-
- if (class_instances.getTargets().size()>1 ) {
+ public void internalBuild(SolutionSet sol_set, Learner _trainer) {
+ _trainer.setInputSpec(sol_set.getInputSpec());
+ if (sol_set.getTargets().size()>1) {
//throw new FeatureNotSupported("There is more than 1 target candidates");
- if (slog.error() !=null)
- slog.error().log("There is more than 1 target candidates\n");
+ if (flog.error() !=null)
+ flog.error().log("There is more than 1 target candidates");
+
System.exit(0);
// TODO put the feature not supported exception || implement it
} else if (_trainer.getTargetDomain().getCategoryCount() >2) {
@@ -70,25 +62,12 @@
System.exit(0);
}
- int split_idx = (int)(trainRatio * class_instances.getSize());
- int split_idx2 = split_idx + (int)(testRatio * class_instances.getSize());
- InstanceList train_instances = class_instances.subList(0, split_idx);
- InstanceList test_instances = class_instances.subList(split_idx, split_idx2);//class_instances.getSize());
-
-
- int N = train_instances.getSize();
+ int N = sol_set.getTrainSet().getSize();
int NUM_DATA = (int)(TREE_SIZE_RATIO * N); // TREE_SIZE_RATIO = 1.0, all training data is used to train the trees again again
_trainer.setTrainingDataSizePerTree(NUM_DATA);
/* tree_capacity number of data fed to each tree, there are FOREST_SIZE trees*/
_trainer.setTrainingDataSize(NUM_DATA);
-
-// int N = class_instances.getSize();
-// int NUM_DATA = (int)(TREE_SIZE_RATIO * N); // TREE_SIZE_RATIO = 1.0, all data is used to train the trees again again
-// _trainer.setTrainingDataSizePerTree(NUM_DATA);
-//
-// /* all data fed to each tree, the same data?? */
-// _trainer.setTrainingDataSize(NUM_DATA); // TODO????
forest = new ArrayList<DecisionTree> (FOREST_SIZE);
@@ -98,9 +77,9 @@
for (int index_i=0; index_i<NUM_DATA; index_i++) {
weights[index_i] = 1.0d/(double)NUM_DATA;
// class_instances.getInstance(index_i).setWeight(weights[index_i] * (double)NUM_DATA);
- train_instances.getInstance(index_i).setWeight(weights[index_i] * (double)NUM_DATA);
+ sol_set.getTrainSet().getInstance(index_i).setWeight(weights[index_i] * (double)NUM_DATA);
if (slog.debug() != null)
- slog.debug().log(index_i+" new weight:"+train_instances.getInstance(index_i).getWeight()+ "\n");
+ slog.debug().log(index_i+" new weight:"+sol_set.getTrainSet().getInstance(index_i).getWeight()+ "\n");
}
int i = 0;
@@ -111,7 +90,7 @@
// else
// bag = Util.bag_wo_rep(NUM_DATA, N);
- InstanceList working_instances = train_instances;//class_instances; //.getInstances(bag);
+ InstanceList working_instances = sol_set.getTrainSet();//class_instances; //.getInstances(bag);
DecisionTree dt = _trainer.train_tree(working_instances);
dt.setID(i);
@@ -183,34 +162,208 @@
}
- train_evaluation.add(single_tester.test(train_instances));
- test_evaluation.add(single_tester.test(test_instances));
+
forest.add(dt);
// the DecisionTreeMerger will visit the decision tree and add the paths that have not been seen yet to the list
merger.add(dt);
+ // adding to the set of solutions
+ Tester t = getTester(dt);
+ Stats train = t.test(sol_set.getTrainSet());
+ Stats test = t.test(sol_set.getTestSet());
+ Solution one = new Solution(dt, sol_set.getTrainSet());
+ one.setTestList(sol_set.getTestSet());
+ one.setTrainStats(train);
+ one.setTestStats(test);
+ sol_set.addSolution(one);
+
if (slog.stat() !=null)
slog.stat().stat(".");
}
- tester = new BoostedTester(forest, getAccuracies());
- train_evaluation.add(tester.test(train_instances));
- test_evaluation.add(tester.test(test_instances));
+
+
+ Tester global_tester = getTester(forest, getAccuracies());
+ Stats train = global_tester.test(sol_set.getTrainSet());
+ Stats test = global_tester.test(sol_set.getTestSet());
+ sol_set.setGlobalTrainStats(train);
+ sol_set.setGlobalTestStats(test);
+
+ //System.exit(0);
// TODO how to compute a best tree from the forest
- int best_id = getMinTestId();
- best = merger.getBest();
- if (best == null)
- best = forest.get(best_id);
+ int best_id = sol_set.getMinTestId();
+// best = merger.getBest();
+// if (best == null)
+// best = forest.get(best_id);
- train_evaluation.add(train_evaluation.get(best_id));
- test_evaluation.add(test_evaluation.get(best_id));
+// best_solution
+ sol_set.setBestSolutionId(best_id);
- //_trainer.setBestTree(forest.get(0));
- //this.c45 = dt;
+ return;
+
}
+// public void build(Memory mem, Learner _trainer) {
+//
+// final InstanceList class_instances = mem.getClassInstances();
+// _trainer.setInputData(class_instances);
+//
+//
+// if (class_instances.getTargets().size()>1 ) {
+// //throw new FeatureNotSupported("There is more than 1 target candidates");
+// if (slog.error() !=null)
+// slog.error().log("There is more than 1 target candidates\n");
+// System.exit(0);
+// // TODO put the feature not supported exception || implement it
+// } else if (_trainer.getTargetDomain().getCategoryCount() >2) {
+// if (slog.error() !=null)
+// slog.error().log("The target domain is not binary!!!\n");
+// System.exit(0);
+// }
+//
+// int split_idx = (int)(trainRatio * class_instances.getSize());
+// int split_idx2 = split_idx + (int)(testRatio * class_instances.getSize());
+//
+// InstanceList train_instances = class_instances.subList(0, split_idx);
+// InstanceList test_instances = class_instances.subList(split_idx, split_idx2);//class_instances.getSize());
+//
+//
+// int N = train_instances.getSize();
+// int NUM_DATA = (int)(TREE_SIZE_RATIO * N); // TREE_SIZE_RATIO = 1.0, all training data is used to train the trees again again
+// _trainer.setTrainingDataSizePerTree(NUM_DATA);
+// /* tree_capacity number of data fed to each tree, there are FOREST_SIZE trees*/
+// _trainer.setTrainingDataSize(NUM_DATA);
+//
+//// int N = class_instances.getSize();
+//// int NUM_DATA = (int)(TREE_SIZE_RATIO * N); // TREE_SIZE_RATIO = 1.0, all data is used to train the trees again again
+//// _trainer.setTrainingDataSizePerTree(NUM_DATA);
+////
+//// /* all data fed to each tree, the same data?? */
+//// _trainer.setTrainingDataSize(NUM_DATA); // TODO????
+//
+//
+// forest = new ArrayList<DecisionTree> (FOREST_SIZE);
+// classifier_accuracy = new ArrayList<Double>(FOREST_SIZE);
+// // weight for each instance - the higher the weight, the more the instance influences the classifier learned.
+// double [] weights = new double [NUM_DATA];
+// for (int index_i=0; index_i<NUM_DATA; index_i++) {
+// weights[index_i] = 1.0d/(double)NUM_DATA;
+//// class_instances.getInstance(index_i).setWeight(weights[index_i] * (double)NUM_DATA);
+// train_instances.getInstance(index_i).setWeight(weights[index_i] * (double)NUM_DATA);
+// if (slog.debug() != null)
+// slog.debug().log(index_i+" new weight:"+train_instances.getInstance(index_i).getWeight()+ "\n");
+// }
+//
+// int i = 0;
+//// int[] bag;
+// while (i++ < FOREST_SIZE ) {
+//// if (WITH_REP)
+//// bag = Util.bag_w_rep(NUM_DATA, N);
+//// else
+//// bag = Util.bag_wo_rep(NUM_DATA, N);
+//
+// InstanceList working_instances = train_instances;//class_instances; //.getInstances(bag);
+// DecisionTree dt = _trainer.train_tree(working_instances);
+// dt.setID(i);
+//
+// double error = 0.0, sum_weight = 0.0;
+// SingleTreeTester single_tester= new SingleTreeTester(dt);
+// for (int index_i = 0; index_i < NUM_DATA; index_i++) {
+//// Integer result = t.test(class_instances.getInstance(index_i));
+// Integer result = single_tester.test(working_instances.getInstance(index_i));
+// sum_weight += weights[index_i];
+// if (result == Stats.INCORRECT) {
+// error += weights[index_i];
+// if (slog.debug() != null)
+// slog.debug().log("[e:"+error+" w:"+weights[index_i]+ "] ");
+// }
+// }
+//
+// error = error / sum_weight; // forgotton
+// if (error > 0.0f) {
+// double alpha = Util.ln( (1.0d-error)/error ) / 2.0d;
+//
+// if (error < 0.5d) {
+// // The classification accuracy of the weak classifier
+// classifier_accuracy.add(alpha);
+//
+// double norm_fact= 0.0d;
+// // Boosting the missclassified instances
+// for (int index_i = 0; index_i < NUM_DATA; index_i++) {
+//// Integer result = t.test(class_instances.getInstance(index_i));//TODO dont need to test two times
+//// Integer result = t.test(class_instances.getInstance(index_i));
+// Integer result = single_tester.test(working_instances.getInstance(index_i));
+// switch (result) {
+// case Stats.INCORRECT:
+// weights[index_i] = weights[index_i] * Util.exp(alpha);
+// break;
+// case Stats.CORRECT: // if it is correct do not update
+// //weights[index_i] = weights[index_i] * Util.exp(-1.0d * alpha);
+// break;
+// case Stats.UNKNOWN:
+// if (slog.error() !=null)
+// slog.error().log("Unknown situation bok\n");
+// System.exit(0);
+// break;
+// }
+// norm_fact += weights[index_i];
+// }
+// // Normalization of the weights
+// for (int index_i = 0; index_i < NUM_DATA; index_i++) {
+// weights[index_i] = weights[index_i] / norm_fact;
+// //class_instances.getInstance(index_i).setWeight(weights[index_i] * (double)NUM_DATA);
+// working_instances.getInstance(index_i).setWeight(weights[index_i] * (double)NUM_DATA);
+// }
+// } else {
+// if (slog.debug() != null)
+// slog.debug().log("The error="+error+" alpha:"+alpha+ "\n");
+// if (slog.error() != null)
+// slog.error().log("error:"+error + " alpha will be negative and the weights of the training samples will be updated in the wrong direction"+"\n");
+// FOREST_SIZE = i-1;//ignore the current tree
+// break;
+// }
+//
+// } else {
+// if (slog.stat() != null) {
+// slog.stat().log("\n Boosting ends: ");
+// slog.stat().log("All instances classified correctly TERMINATE, forest size:"+i+ "\n");
+// }
+// // What to do here??
+// FOREST_SIZE = i;
+// classifier_accuracy.add(10.0); // TODO add a very big number
+//
+//
+// }
+// train_evaluation.add(single_tester.test(train_instances));
+// test_evaluation.add(single_tester.test(test_instances));
+//
+// forest.add(dt);
+// // the DecisionTreeMerger will visit the decision tree and add the paths that have not been seen yet to the list
+// merger.add(dt);
+//
+// if (slog.stat() !=null)
+// slog.stat().stat(".");
+//
+// }
+//
+// tester = new BoostedTester(forest, getAccuracies());
+// train_evaluation.add(tester.test(train_instances));
+// test_evaluation.add(tester.test(test_instances));
+// // TODO how to compute a best tree from the forest
+// int best_id = getMinTestId();
+// best = merger.getBest();
+// if (best == null)
+// best = forest.get(best_id);
+//
+// train_evaluation.add(train_evaluation.get(best_id));
+// test_evaluation.add(test_evaluation.get(best_id));
+//
+// //_trainer.setBestTree(forest.get(0));
+// //this.c45 = dt;
+// }
+
public ArrayList<DecisionTree> getTrees() {
return forest;
}
@@ -227,71 +380,50 @@
}
- public int getMinTestId() {
- double min = 1.0;
- int id = -1;
- for (int i=0; i< FOREST_SIZE; i++ ) {
- Stats test_s=test_evaluation.get(i);
- Stats train_s=train_evaluation.get(i);
- double test_error = ((double)test_s.getResult(Stats.INCORRECT)/(double)test_s.getTotal());
- double train_error = ((double)train_s.getResult(Stats.INCORRECT)/(double)train_s.getTotal());
- if (test_error < min) {
- min = test_error;
- id = i;
- } else if (test_error == min) {
- Stats old = train_evaluation.get(id);
- double train_old = ((double)old.getResult(Stats.INCORRECT)/(double)old.getTotal());
- if (train_error < train_old) {
- min = test_error;
- id = i;
- }
- }
-
- }
- return id;
-
+ public Tester getTester(ArrayList<DecisionTree> boosted_forest, ArrayList<Double> acc) {
+ return new BoostedTester(boosted_forest, acc);
}
- public DecisionTree getTree() {
- return best;
+ public Solution getBestSolution() {
+ return solutions.getBestSolution();
}
- public void printResults(String executionSignature) {
- StringBuffer sb = new StringBuffer();
- sb.append("#"+ Stats.getErrors());
- for (int i =0 ; i<FOREST_SIZE; i++) {
- sb.append(train_evaluation.get(i).print2string() + "\n");
- }
- sb.append( "\n\n");
- for (int i =0 ; i<FOREST_SIZE; i++) {
- sb.append(test_evaluation.get(i).print2string() + "\n");
- }
- sb.append( "\n");
- sb.append(train_evaluation.get(FOREST_SIZE).print2string() + "\n");
- sb.append( "\n");
- sb.append(test_evaluation.get(FOREST_SIZE).print2string() + "\n");
-
- sb.append( "\n");
- sb.append(train_evaluation.get(FOREST_SIZE+1).print2string() + "\n");
- sb.append( "\n");
- sb.append(test_evaluation.get(FOREST_SIZE+1).print2string() + "\n");
-
- StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, false);
- }
-
- public void printLatex(String executionSignature) {
- StringBuffer sb = new StringBuffer();
- sb.append("#"+ Stats.getErrors());
- for (int i =0 ; i<FOREST_SIZE; i++) {
- sb.append(train_evaluation.get(i).print4Latex() +test_evaluation.get(i).print4Latex() + "\\\\"+"\n");
- }
-
- sb.append( "\n");
- sb.append(train_evaluation.get(FOREST_SIZE).print4Latex() +test_evaluation.get(FOREST_SIZE).print4Latex() +"\\\\"+"\n");
-
- sb.append( "\n");
- sb.append(train_evaluation.get(FOREST_SIZE+1).print4Latex() +test_evaluation.get(FOREST_SIZE+1).print4Latex()+"\\\\"+ "\n");
- sb.append( "\n");
-
- StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, true);
- }
+// public void printResults(String executionSignature) {
+// StringBuffer sb = new StringBuffer();
+// sb.append("#"+ Stats.getErrors());
+// for (int i =0 ; i<FOREST_SIZE; i++) {
+// sb.append(train_evaluation.get(i).print2string() + "\n");
+// }
+// sb.append( "\n\n");
+// for (int i =0 ; i<FOREST_SIZE; i++) {
+// sb.append(test_evaluation.get(i).print2string() + "\n");
+// }
+// sb.append( "\n");
+// sb.append(train_evaluation.get(FOREST_SIZE).print2string() + "\n");
+// sb.append( "\n");
+// sb.append(test_evaluation.get(FOREST_SIZE).print2string() + "\n");
+//
+// sb.append( "\n");
+// sb.append(train_evaluation.get(FOREST_SIZE+1).print2string() + "\n");
+// sb.append( "\n");
+// sb.append(test_evaluation.get(FOREST_SIZE+1).print2string() + "\n");
+//
+// StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, false);
+// }
+//
+// public void printLatex(String executionSignature) {
+// StringBuffer sb = new StringBuffer();
+// sb.append("#"+ Stats.getErrors());
+// for (int i =0 ; i<FOREST_SIZE; i++) {
+// sb.append(train_evaluation.get(i).print4Latex() +test_evaluation.get(i).print4Latex() + "\\\\"+"\n");
+// }
+//
+// sb.append( "\n");
+// sb.append(train_evaluation.get(FOREST_SIZE).print4Latex() +test_evaluation.get(FOREST_SIZE).print4Latex() +"\\\\"+"\n");
+//
+// sb.append( "\n");
+// sb.append(train_evaluation.get(FOREST_SIZE+1).print4Latex() +test_evaluation.get(FOREST_SIZE+1).print4Latex()+"\\\\"+ "\n");
+// sb.append( "\n");
+//
+// StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, true);
+// }
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -8,18 +8,19 @@
import org.drools.learner.InstanceList;
import org.drools.learner.Memory;
import org.drools.learner.Stats;
+import org.drools.learner.builder.test.SingleTreeTester;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
-public class AdaBoostKBuilder implements DecisionTreeBuilder{
+public class AdaBoostKBuilder extends DecisionTreeBuilder{
private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(AdaBoostKBuilder.class, SimpleLogger.DEFAULT_LEVEL);
private static SimpleLogger slog = LoggerFactory.getSysOutLogger(AdaBoostKBuilder.class, SimpleLogger.DEBUG);
private TreeAlgo algorithm = TreeAlgo.BOOST_K; // default bagging, TODO boosting
- private double trainRatio = Util.TRAINING_RATIO, testRatio = Util.TESTING_RATIO;
+// private double trainRatio = Util.TRAINING_RATIO, testRatio = Util.TESTING_RATIO;
private static int FOREST_SIZE = 10;
private static final double TREE_SIZE_RATIO = 0.9;
@@ -28,7 +29,7 @@
private ArrayList<DecisionTree> forest;
private ArrayList<Double> classifier_accuracy;
- private DecisionTree best;
+ private Solution best_solution;
//private Learner trainer;
private DecisionTreeMerger merger;
@@ -38,19 +39,18 @@
merger = new DecisionTreeMerger();
}
- public void setTrainRatio(double ratio) {
- trainRatio = ratio;
- }
- public void setTestRatio(double ratio) {
- testRatio = ratio;
- }
- public void build(Memory mem, Learner _trainer) {
+// public void setTrainRatio(double ratio) {
+// trainRatio = ratio;
+// }
+// public void setTestRatio(double ratio) {
+// testRatio = ratio;
+// }
+ public void internalBuild(SolutionSet sol, Learner _trainer) {
- final InstanceList class_instances = mem.getClassInstances();
- _trainer.setInputData(class_instances);
+ _trainer.setInputSpec(sol.getInputSpec());
- if (class_instances.getTargets().size()>1 ) {
+ if (sol.getInputSpec().getTargets().size()>1 ) {
//throw new FeatureNotSupported("There is more than 1 target candidates");
if (flog.error() !=null)
flog.error().log("There is more than 1 target candidates");
@@ -61,7 +61,7 @@
flog.warn().log("The target domain is binary!!! Do u really need that one");
}
- int N = class_instances.getSize();
+ int N = sol.getTrainSet().getSize();
//_trainer.setTrainingDataSize(N); not only N data is fed.
int K = _trainer.getTargetDomain().getCategoryCount();
@@ -77,7 +77,7 @@
double[][] weight = new double[M][K];
for (int index_i=0; index_i<M; index_i++) {
for (int index_j=0; index_j<K; index_j++) {
- Instance inst_i = class_instances.getInstance(index_i);
+ Instance inst_i = sol.getTrainSet().getInstance(index_i);
Object instance_target = inst_i.getAttrValue(_trainer.getTargetDomain().getFReferenceName());
Object instance_target_category = _trainer.getTargetDomain().getCategoryOf(instance_target);
@@ -106,14 +106,14 @@
// b. Train h_t(x) by minimizing loss function
- InstanceList working_instances = class_instances.getInstances(bag);
+ InstanceList working_instances = sol.getTrainSet().getInstances(bag);
DecisionTree dt = _trainer.train_tree(working_instances);
dt.setID(i);
double error = 0.0;
SingleTreeTester t= new SingleTreeTester(dt);
for (int index_i = 0; index_i < M; index_i++) {
- Integer result = t.test(class_instances.getInstance(index_i));
+ Integer result = t.test(sol.getTrainSet().getInstance(index_i));
if (result == Stats.INCORRECT) {
//error += distribution.get(index_i);
@@ -130,7 +130,7 @@
double norm_fact= 0.0d;
// Update the weight matrix wij:
for (int index_i = 0; index_i < M; index_i++) {
- Integer result = t.test(class_instances.getInstance(index_i));//TODO dont need to test two times
+ Integer result = t.test(sol.getTrainSet().getInstance(index_i));//TODO dont need to test two times
switch (result) {
case Stats.INCORRECT:
//distribution.set(index_i, distribution.get(index_i) * Util.exp(alpha));
@@ -183,9 +183,9 @@
}
// TODO how to compute a best tree from the forest
//_trainer.setBestTree(forest.get(0));
- best = merger.getBest();
- if (best == null)
- best = forest.get(0);
+// best = merger.getBest();
+// if (best == null)
+// best = forest.get(0);
//this.c45 = dt;
}
@@ -205,7 +205,7 @@
}
- public DecisionTree getTree() {
- return best;
+ public Solution getBestSolution() {
+ return best_solution;
}
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -44,7 +44,7 @@
LeafNode classifiedNode = new LeafNode(dt.getTargetDomain() /* target domain*/,
data_stats.get_winner_class() /*winner target category*/);
- classifiedNode.setRank( (double)data_stats.getSum()/
+ classifiedNode.setRank( data_stats.getSum()/
(double)this.getTrainingDataSize()/* total size of data fed to dt*/);
classifiedNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
classifiedNode.setNumClassification(data_stats.getSum()); //num of classified instances at the leaf node
@@ -60,7 +60,7 @@
Object winner = data_stats.get_winner_class(); /*winner target category*/
LeafNode noAttributeLeftNode = new LeafNode(dt.getTargetDomain() /* target domain*/,
winner);
- noAttributeLeftNode.setRank((double)data_stats.getVoteFor(winner)/
+ noAttributeLeftNode.setRank(data_stats.getVoteFor(winner)/
(double)this.getTrainingDataSize() /* total size of data fed to dt*/);
noAttributeLeftNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
noAttributeLeftNode.setNumClassification(data_stats.getVoteFor(winner)); //num of classified instances at the leaf node
@@ -69,7 +69,7 @@
/* we need to know how many guys cannot be classified and who these guys are */
data_stats.missClassifiedInstances(missclassified_data);
- dt.setTrainingError(dt.getTrainingError() + data_stats.getSum()/getTrainingDataSize());
+ dt.changeTrainError((data_stats.getSum() - data_stats.getVoteFor(winner))/(double)getTrainingDataSize());
return noAttributeLeftNode;
}
@@ -93,7 +93,7 @@
/* we need to know how many guys cannot be classified and who these guys are */
data_stats.missClassifiedInstances(missclassified_data);
- dt.setTrainingError(dt.getTrainingError() + (data_stats.getSum()-data_stats.getVoteFor(winner))/getTrainingDataSize());
+ dt.changeTrainError((data_stats.getSum() - data_stats.getVoteFor(winner))/(double)getTrainingDataSize());
return majorityNode;
}
}
@@ -103,7 +103,7 @@
TreeNode currentNode = new TreeNode(node_domain);
currentNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
- currentNode.setRank((double)data_stats.getSum()/
+ currentNode.setRank(data_stats.getSum()/
(double)this.getTrainingDataSize() /* total size of data fed to trainer*/);
currentNode.setInfoMea(best_attr_eval.attribute_eval);
//what the highest represented class is and what proportion of items at that node actually are that class
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeBuilder.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeBuilder.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -3,21 +3,37 @@
import org.drools.learner.DecisionTree;
import org.drools.learner.Memory;
+import org.drools.learner.builder.test.SingleTreeTester;
+import org.drools.learner.builder.test.Tester;
-public interface DecisionTreeBuilder {
+public abstract class DecisionTreeBuilder {
//public static final int SINGLE = 1, BAG = 2, BOOST = 3;
public static enum TreeAlgo { SINGLE, BAG, BOOST, BOOST_K }
- void build(Memory wm, Learner trainer);
+ public SolutionSet solutions;
+ public int best_solution_id;
-// public Learner getLearner();
+ public SolutionSet build(Memory wm, Learner trainer) {
+ solutions = beforeBuild(wm);
+ internalBuild(solutions, trainer);
+ return solutions;
+ }
- public TreeAlgo getTreeAlgo();
+ protected SolutionSet beforeBuild(Memory wm) {
+ return new SolutionSet(wm);
+ }
- public DecisionTree getTree();
+ protected abstract void internalBuild(SolutionSet sol, Learner trainer);
- public void setTrainRatio(double ratio);
- public void setTestRatio(double ratio);
+ public abstract TreeAlgo getTreeAlgo();
+ public abstract Solution getBestSolution();
+
+
+ public Tester getTester(DecisionTree dt) {
+ return new SingleTreeTester(dt);
+ }
+
+
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -6,8 +6,14 @@
import org.drools.learner.DecisionTree;
import org.drools.learner.DecisionTreePruner;
import org.drools.learner.Memory;
+import org.drools.learner.Stats;
+import org.drools.learner.StatsPrinter;
import org.drools.learner.builder.Learner.DataType;
import org.drools.learner.builder.Learner.DomainAlgo;
+import org.drools.learner.builder.test.BoostedTester;
+import org.drools.learner.builder.test.ForestTester;
+import org.drools.learner.builder.test.SingleTreeTester;
+import org.drools.learner.builder.test.Tester;
import org.drools.learner.builder.DecisionTreeBuilder.TreeAlgo;
import org.drools.learner.eval.CrossValidation;
import org.drools.learner.eval.Entropy;
@@ -52,124 +58,83 @@
/* create the memory */
Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
- single_builder.build(mem, learner);//obj_class, target_attr, working_attr
+ mem.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
+ mem.setTestRatio(Util.DEFAULT_TESTING_RATIO);
+ mem.processTestSet();
- SingleTreeTester tester = new SingleTreeTester(single_builder.getTree());
- tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature, false);
- //Tester.test(c45, mem.getClassInstances());
+ //Ruler save_me_please = new Ruler()
- single_builder.getTree().setSignature(executionSignature);
- return single_builder.getTree();
- }
- public static DecisionTree createSingleC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSingleC45(wm, obj_class, new Entropy());
- }
-
- public static DecisionTree createSingleC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSingleC45(wm, obj_class, new GainRatio());
- }
-
- protected static DecisionTree createSingleC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
- DataType data = Learner.DEFAULT_DATA;
- C45Learner learner = new C45Learner(h);
+ SolutionSet product = single_builder.build(mem, learner);//obj_class, target_attr, working_attr
+ Solution s1 = product.getSolutions().get(0);
+ Tester t = single_builder.getTester(s1.getTree());
+ StatsPrinter.printLatex(t.test(s1.getList()), t.test(s1.getTestList()), executionSignature, false);
+ //single_builder.printLatex(executionSignature);
+ //
- SingleTreeBuilder single_builder = new SingleTreeBuilder();
+// DecisionTreePruner pruner = new DecisionTreePruner();
+// for (Solution sol: product.getSolutions())
+// pruner.prun_to_estimate(sol);
+// Solution s2 = pruner.getBestSolution();
+// Tester t2 = single_builder.getTester(pruner.getBestSolution().getTree());
+// StatsPrinter.printLatex(t2.test(s2.getList()), t2.test(s2.getTestList()), executionSignature, false);
-// String algo_suffices = org.drools.learner.deprecated.DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
-// String executionSignature = org.drools.learner.deprecated.DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
- String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
- String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-
- /* create the memory */
- Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
- single_builder.build(mem, learner);//obj_class, target_attr, working_attr
-
- SingleTreeTester tester = new SingleTreeTester(single_builder.getTree());
- tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature, false);
- //Tester.test(c45, mem.getClassInstances());
-
- single_builder.getTree().setSignature(executionSignature);
- return single_builder.getTree();
+ single_builder.getBestSolution().getTree().setSignature(executionSignature);
+ return single_builder.getBestSolution().getTree();
}
- public static DecisionTree createBagC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBagC45( wm,obj_class, new Entropy());
+ /************************************************************************************************************************
+ * Single Tree Builder Algorithms with C4.5
+ * @param wm
+ * @param obj_class
+ * @return
+ * @throws FeatureNotSupported
+ */
+ public static DecisionTree createSingleC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createSingleC45(wm, obj_class, new Entropy(), criteria, null);
}
- public static DecisionTree createBagC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBagC45( wm,obj_class, new GainRatio());
- }
- protected static DecisionTree createBagC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
- DataType data = Learner.DEFAULT_DATA;
- C45Learner learner = new C45Learner(h);
- ForestBuilder forest = new ForestBuilder();
-
-// String algo_suffices = org.drools.learner.deprecated.DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), forest.getTreeAlgo());
-// String executionSignature = org.drools.learner.deprecated.DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
- String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), forest.getTreeAlgo());
- String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-
- /* create the memory */
- Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
- forest.build(mem, learner);
- //forest.clearForest(10);
-
- ForestTester tester = new ForestTester(forest.getTrees());
- tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature, false);
- //forest.test(mem.getClassInstances(), Util.DRL_DIRECTORY+executionSignature);
-
- //Tester bla => test(c45, mem.getClassInstances());
- forest.getTree().setSignature(executionSignature);
- return forest.getTree();
+ public static DecisionTree createSingleC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createSingleC45(wm, obj_class, new GainRatio(), criteria, null);
}
- public static DecisionTree createBoostedC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBoostedC45(wm, obj_class, new Entropy());
+ public static DecisionTree createSingleC45E_Stop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(3);
+ criteria.add(new EstimatedNodeSize(0.5));
+ criteria.add(new ImpurityDecrease());
+ criteria.add(new MaximumDepth(50));
+ return createSingleC45(wm, obj_class, new Entropy(), criteria, null);
}
- public static DecisionTree createBoostedC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBoostedC45(wm, obj_class, new GainRatio());
- }
- public static DecisionTree createBoostedC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
- DataType data = Learner.DEFAULT_DATA;
-
- C45Learner learner = new C45Learner(h);
- AdaBoostBuilder forest = new AdaBoostBuilder();
-
-// String algo_suffices = org.drools.learner.deprecated.DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), forest.getTreeAlgo());
-// String executionSignature = org.drools.learner.deprecated.DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
- String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), forest.getTreeAlgo());
- String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-
-
- /* create the memory */
- Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
- forest.build(mem, learner);
- //forest.clearForest(10);
-
- BoostedTester tester = new BoostedTester(forest.getTrees(), forest.getAccuracies());
- tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature, false);
- //forest.test(mem.getClassInstances(), Util.DRL_DIRECTORY+executionSignature);
-
- //Tester bla => test(c45, mem.getClassInstances());
- forest.getTree().setSignature(executionSignature);
- return forest.getTree();
+ public static DecisionTree createSingleC45G_Stop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(3);
+ criteria.add(new EstimatedNodeSize(0.5));
+ criteria.add(new ImpurityDecrease());
+ criteria.add(new MaximumDepth(50));
+ return createSingleC45(wm, obj_class, new GainRatio(), criteria, null);
}
- public static DecisionTree createSingleC45E_Stopped(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSingleC45_Stop(wm, obj_class, new Entropy());
+ public static DecisionTree createSingleC45E_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ DecisionTreePruner pruner = new DecisionTreePruner();
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ criteria.add(new EstimatedNodeSize(0.5));
+ return createSingleC45(wm, obj_class, new Entropy(), criteria ,pruner);
}
- public static DecisionTree createSingleC45G_Stopped(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSingleC45(wm, obj_class, new GainRatio());
+ public static DecisionTree createSingleC45G_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ DecisionTreePruner pruner = new DecisionTreePruner();
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ criteria.add(new EstimatedNodeSize(0.5));
+ return createSingleC45(wm, obj_class, new GainRatio(), criteria ,pruner);
}
- protected static DecisionTree createSingleC45_Stop(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
+ protected static DecisionTree createSingleC45(WorkingMemory wm, Class<? extends Object> obj_class,
+ Heuristic h,
+ ArrayList<StoppingCriterion> criteria,
+ DecisionTreePruner pruner) throws FeatureNotSupported {
DataType data = Learner.DEFAULT_DATA;
-
C45Learner learner = new C45Learner(h);
- learner.addStoppingCriteria(new EstimatedNodeSize(0.5));
-
SingleTreeBuilder single_builder = new SingleTreeBuilder();
String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
@@ -177,382 +142,254 @@
/* create the memory */
Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
-
- int max_depth = (int)((mem.getClassInstances().getSchema().getAttrNames().size() - 1)*0.70 );
- learner.addStoppingCriteria(new MaximumDepth(max_depth));
-
- single_builder.build(mem, learner);//obj_class, target_attr, working_attr
+ mem.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
+ mem.setTestRatio(Util.DEFAULT_TESTING_RATIO);
+ mem.processTestSet();
- SingleTreeTester tester = new SingleTreeTester(single_builder.getTree());
- tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature, false);
+ for (StoppingCriterion sc: criteria) {
+ if (sc instanceof MaximumDepth) {
+ int max_depth = (int)((mem.getClassInstances().getSchema().getAttrNames().size() - 1)*0.70 );
+ ((MaximumDepth)sc).setDepth(max_depth);
+ }
+ learner.addStoppingCriteria(sc);
+ }
- //Tester.test(c45, mem.getClassInstances());
- tester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
+ SolutionSet product = single_builder.build(mem, learner);//obj_class, target_attr, working_attr
+ Solution s1 = product.getSolutions().get(0);
+ StatsPrinter.printLatexComment("ORIGINAL TREE", executionSignature, false);
+ StatsPrinter.printLatexComment(Stats.getErrors(), executionSignature, true);
+ StatsPrinter.printLatex(s1.getTrainStats(), s1.getTestStats(), executionSignature, true);
- single_builder.getTree().setSignature(executionSignature);
- return single_builder.getTree();
+ Tester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
+
+ if (pruner != null) {
+ for (Solution sol: product.getSolutions())
+ pruner.prun_to_estimate(sol);
+ Solution s2 = pruner.getBestSolution();
+ Tester t2 = single_builder.getTester(pruner.getBestSolution().getTree());
+ StatsPrinter.printLatexComment("Pruned TREE", executionSignature, true);
+ StatsPrinter.printLatex(t2.test(s2.getList()), t2.test(s2.getTestList()), executionSignature, true);
+ }
+ single_builder.getBestSolution().getTree().setSignature(executionSignature);
+ return single_builder.getBestSolution().getTree();
}
- public static DecisionTree createSingleC45E_StoppedTest(WorkingMemory wm, WorkingMemory wm_test, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSingleC45_StopTest(wm, wm_test, obj_class, new Entropy());
- }
- protected static DecisionTree createSingleC45_StopTest(WorkingMemory wm, WorkingMemory wm_test, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
- DataType data = Learner.DEFAULT_DATA;
-
- C45Learner learner = new C45Learner(h);
- learner.addStoppingCriteria(new EstimatedNodeSize(0.5));
-
- SingleTreeBuilder single_builder = new SingleTreeBuilder();
-
- String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
- String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-
- /* create the memory */
- Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
-
- int max_depth = (int)((mem.getClassInstances().getSchema().getAttrNames().size() - 1)*0.70 );
- learner.addStoppingCriteria(new MaximumDepth(max_depth));
-
- single_builder.build(mem, learner);//obj_class, target_attr, working_attr
-
- SingleTreeTester tester = new SingleTreeTester(single_builder.getTree());
- tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature, false);
-
- //Tester.test(c45, mem.getClassInstances());
- tester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
-
- Memory test_mem = Memory.createTestFromWorkingMemory(mem, wm_test, obj_class, learner.getDomainAlgo(), data);
- tester.printStats(tester.test(test_mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature, true);
-
- single_builder.getTree().setSignature(executionSignature);
- return single_builder.getTree();
+
+ /************************************************************************************************************************
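+ * Bagging Algorithms: factory methods that train C45Learner trees through a ForestBuilder.
+ * @param wm working memory from which the training Memory is created
+ * @param obj_class class handed to Memory.createFromWorkingMemory and used to derive the execution signature
+ * @return the tree of the builder's best solution, with the execution signature set on it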
+ * @throws FeatureNotSupported
+ */
+ public static DecisionTree createBagC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createBagC45(wm, obj_class, new Entropy(), criteria, null);
}
- public static DecisionTree createSingleC45E_Test(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSingleC45_Test(wm, obj_class, new Entropy());
+ public static DecisionTree createBagC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createBagC45(wm, obj_class, new GainRatio(), criteria, null);
}
- public static DecisionTree createSingleC45G_Test(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSingleC45_Test(wm, obj_class, new GainRatio());
- }
- protected static DecisionTree createSingleC45_Test(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
- DataType data = Learner.DEFAULT_DATA;
-
- C45Learner learner = new C45Learner(h);
- //learner.addStoppingCriteria(new EstimatedNodeSize(0.5));
-
- SingleTreeBuilder single_builder = new SingleTreeBuilder();
- single_builder.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
- single_builder.setTestRatio(Util.DEFAULT_TESTING_RATIO);
-
- String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
- String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-
- /* create the memory */
- Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
-
- int max_depth = (int)((mem.getClassInstances().getSchema().getAttrNames().size() - 1)*0.70 );
- //learner.addStoppingCriteria(new MaximumDepth(max_depth));
-
- single_builder.build(mem, learner);//obj_class, target_attr, working_attr
- single_builder.printResults(executionSignature);
- single_builder.printLatex(executionSignature);
-
- SingleTreeTester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
-
- single_builder.getTree().setSignature(executionSignature);
- return single_builder.getTree();
+ public static DecisionTree createBagC45E_Stop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(3);
+ criteria.add(new EstimatedNodeSize(0.5));
+ criteria.add(new ImpurityDecrease());
+ criteria.add(new MaximumDepth(50));
+ return createBagC45(wm, obj_class, new Entropy(), criteria, null);
}
-
- public static DecisionTree createSingleC45E_StoppedTest(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSingleC45_StopTest(wm, obj_class, new Entropy());
+ public static DecisionTree createBagC45G_Stop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(3);
+ criteria.add(new EstimatedNodeSize(0.5));
+ criteria.add(new ImpurityDecrease());
+ criteria.add(new MaximumDepth(50));
+ return createBagC45(wm, obj_class, new GainRatio(), criteria, null);
}
- public static DecisionTree createSingleC45G_StoppedTest(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSingleC45_StopTest(wm, obj_class, new GainRatio());
- }
- protected static DecisionTree createSingleC45_StopTest(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
- DataType data = Learner.DEFAULT_DATA;
-
- C45Learner learner = new C45Learner(h);
- learner.addStoppingCriteria(new EstimatedNodeSize(0.5));
- learner.addStoppingCriteria(new ImpurityDecrease());
-
- SingleTreeBuilder single_builder = new SingleTreeBuilder();
- single_builder.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
- single_builder.setTestRatio(Util.DEFAULT_TESTING_RATIO);
-
- String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
- String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-
- /* create the memory */
- Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
-
- int max_depth = (int)((mem.getClassInstances().getSchema().getAttrNames().size() - 1)*0.85 );
- learner.addStoppingCriteria(new MaximumDepth(max_depth));
-
- single_builder.build(mem, learner);//obj_class, target_attr, working_attr
- single_builder.printResults(executionSignature);
- single_builder.printLatex(executionSignature);
-
- SingleTreeTester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
-
- single_builder.getTree().setSignature(executionSignature);
- return single_builder.getTree();
+ public static DecisionTree createBagC45E_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ DecisionTreePruner pruner = new DecisionTreePruner();
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createBagC45(wm, obj_class, new Entropy(), criteria ,pruner);
}
- public static DecisionTree createBaggC45E_Test(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBaggC45_Test(wm, obj_class, new Entropy());
+ public static DecisionTree createBagC45G_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ DecisionTreePruner pruner = new DecisionTreePruner();
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createBagC45(wm, obj_class, new GainRatio(), criteria ,pruner);
}
- public static DecisionTree createBaggC45G_Test(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBaggC45_Test(wm, obj_class, new GainRatio());
- }
- protected static DecisionTree createBaggC45_Test(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
+ protected static DecisionTree createBagC45(WorkingMemory wm, Class<? extends Object> obj_class,
+ Heuristic h,
+ ArrayList<StoppingCriterion> criteria,
+ DecisionTreePruner pruner) throws FeatureNotSupported {
DataType data = Learner.DEFAULT_DATA;
-
- C45Learner learner = new C45Learner(h);
- //learner.addStoppingCriteria(new EstimatedNodeSize(0.5));
-
+ C45Learner learner = new C45Learner(h);
ForestBuilder forest = new ForestBuilder();
- forest.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
- forest.setTestRatio(Util.DEFAULT_TESTING_RATIO);
String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), forest.getTreeAlgo());
String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
/* create the memory */
Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
+ mem.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
+ mem.setTestRatio(Util.DEFAULT_TESTING_RATIO);
+ mem.processTestSet();
- int max_depth = (int)((mem.getClassInstances().getSchema().getAttrNames().size() - 1)*0.70 );
- //learner.addStoppingCriteria(new MaximumDepth(max_depth));
+ for (StoppingCriterion sc: criteria) {
+ if (sc instanceof MaximumDepth) {
+ int max_depth = (int)((mem.getClassInstances().getSchema().getAttrNames().size() - 1)*0.70 );
+ ((MaximumDepth)sc).setDepth(max_depth);
+ }
+ learner.addStoppingCriteria(sc);
+ }
- forest.build(mem, learner);
- forest.printResults(executionSignature);
- forest.printLatex(executionSignature);
- SingleTreeTester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
+ SolutionSet product = forest.build(mem, learner);//obj_class, target_attr, working_attr
+ StatsPrinter.printLatexComment("Builder errors", executionSignature, false);
+ StatsPrinter.printLatexComment(Stats.getErrors(), executionSignature, true);
+ StatsPrinter.printLatex(product.getGlobalTrainStats(), product.getGlobalTestStats(), executionSignature, true);
+ StatsPrinter.printLatexComment("Each Original Tree", executionSignature, true);
+ for (Solution s: product.getSolutions()) {
+ StatsPrinter.printLatex(s.getTrainStats(), s.getTestStats(), executionSignature, true);
+ }
+ Solution best_s = product.getBestSolution();
+ StatsPrinter.printLatexComment("Best Original Tree", executionSignature, true);
+ StatsPrinter.printLatex(best_s.getTrainStats(), best_s.getTestStats(), executionSignature, true);
+ Tester t = forest.getTester(best_s.getTree());
+ StatsPrinter.printLatexComment("Best Original Tree(Global)", executionSignature, true);
+ StatsPrinter.printLatex(t.test(product.getTrainSet()), t.test(product.getTestSet()), executionSignature, true);
+
+ Tester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
- forest.getTree().setSignature(executionSignature);
- return forest.getTree();
+
+ if (pruner != null) {
+ for (Solution sol: product.getSolutions())
+ pruner.prun_to_estimate(sol);
+
+ Solution s2 = pruner.getBestSolution();
+ Tester t2 = forest.getTester(s2.getTree());
+ StatsPrinter.printLatexComment("Best Pruned Tree", executionSignature, true);
+ StatsPrinter.printLatex(t2.test(s2.getList()), t2.test(s2.getTestList()), executionSignature, true);
+ Tester t2_global = forest.getTester(s2.getTree());
+ StatsPrinter.printLatexComment("Best Original Tree(Global)", executionSignature, true);
+ StatsPrinter.printLatex(t2_global.test(product.getTrainSet()), t2_global.test(product.getTestSet()), executionSignature, true);
+ }
+
+ forest.getBestSolution().getTree().setSignature(executionSignature);
+ return forest.getBestSolution().getTree();
}
- public static DecisionTree createBaggC45E_StoppedTest(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBaggC45_StopTest(wm, obj_class, new Entropy());
- }
- public static DecisionTree createBaggC45G_StoppedTest(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBaggC45_StopTest(wm, obj_class, new GainRatio());
- }
- protected static DecisionTree createBaggC45_StopTest(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
- DataType data = Learner.DEFAULT_DATA;
-
- C45Learner learner = new C45Learner(h);
- learner.addStoppingCriteria(new EstimatedNodeSize(0.5));
- learner.addStoppingCriteria(new ImpurityDecrease());
-
- ForestBuilder forest = new ForestBuilder();
- forest.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
- forest.setTestRatio(Util.DEFAULT_TESTING_RATIO);
-
- String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), forest.getTreeAlgo());
- String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-
- /* create the memory */
- Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
-
- int max_depth = (int)((mem.getClassInstances().getSchema().getAttrNames().size() - 1)*0.70 );
- learner.addStoppingCriteria(new MaximumDepth(max_depth));
-
- forest.build(mem, learner);
- forest.printResults(executionSignature);
- forest.printLatex(executionSignature);
-
- SingleTreeTester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
-
- forest.getTree().setSignature(executionSignature);
- return forest.getTree();
+
+ /************************************************************************************************************************
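+ * Boosting Algorithms: factory methods that train C45Learner trees through an AdaBoostBuilder.
+ * @param wm working memory from which the training Memory is created
+ * @param obj_class class handed to Memory.createFromWorkingMemory and used to derive the execution signature
+ * @return the tree of the builder's best solution, with the execution signature set on it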
+ * @throws FeatureNotSupported
+ */
+ public static DecisionTree createBoostC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createBoostC45(wm, obj_class, new Entropy(), criteria, null);
}
- public static DecisionTree createBoostedC45E_Test(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBoostedC45_Test(wm, obj_class, new Entropy());
+ public static DecisionTree createBoostC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createBoostC45(wm, obj_class, new GainRatio(), criteria, null);
}
- public static DecisionTree createBoostedC45G_Test(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBoostedC45_Test(wm, obj_class, new GainRatio());
- }
- public static DecisionTree createBoostedC45_Test(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
- DataType data = Learner.DEFAULT_DATA;
-
- C45Learner learner = new C45Learner(h);
- //learner.addStoppingCriteria(new EstimatedNodeSize(0.5));
- AdaBoostBuilder forest = new AdaBoostBuilder();
- forest.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
- forest.setTestRatio(Util.DEFAULT_TESTING_RATIO);
-
- String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), forest.getTreeAlgo());
- String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-
-
- /* create the memory */
- Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
- int max_depth = (int)((mem.getClassInstances().getSchema().getAttrNames().size() - 1)*0.70 );
- //learner.addStoppingCriteria(new MaximumDepth(max_depth));
-
- forest.build(mem, learner);
-
- forest.printResults(executionSignature);
- forest.printLatex(executionSignature);
-
- SingleTreeTester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
-
- //Tester bla => test(c45, mem.getClassInstances());
- forest.getTree().setSignature(executionSignature);
- return forest.getTree();
+ public static DecisionTree createBoostC45E_Stop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(3);
+ criteria.add(new EstimatedNodeSize(0.5));
+ criteria.add(new ImpurityDecrease());
+ criteria.add(new MaximumDepth(50));
+ return createBoostC45(wm, obj_class, new Entropy(), criteria, null);
}
- public static DecisionTree createBoostedC45E_StopTest(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBoostedC45_StopTest(wm, obj_class, new Entropy());
+ public static DecisionTree createBoostC45G_Stop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(3);
+ criteria.add(new EstimatedNodeSize(0.5));
+ criteria.add(new ImpurityDecrease());
+ criteria.add(new MaximumDepth(50));
+ return createBoostC45(wm, obj_class, new GainRatio(), criteria, null);
}
- public static DecisionTree createBoostedC45G_StopTest(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createBoostedC45_StopTest(wm, obj_class, new GainRatio());
- }
- public static DecisionTree createBoostedC45_StopTest(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
- DataType data = Learner.DEFAULT_DATA;
-
- C45Learner learner = new C45Learner(h);
- learner.addStoppingCriteria(new EstimatedNodeSize(0.5));
- learner.addStoppingCriteria(new ImpurityDecrease());
-
- AdaBoostBuilder forest = new AdaBoostBuilder();
- forest.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
- forest.setTestRatio(Util.DEFAULT_TESTING_RATIO);
-
- String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), forest.getTreeAlgo());
- String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-
-
- /* create the memory */
- Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
- int max_depth = (int)((mem.getClassInstances().getSchema().getAttrNames().size() - 1)*0.70 );
- learner.addStoppingCriteria(new MaximumDepth(max_depth));
-
- forest.build(mem, learner);
-
- forest.printResults(executionSignature);
- forest.printLatex(executionSignature);
-
- SingleTreeTester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
-
- //Tester bla => test(c45, mem.getClassInstances());
- forest.getTree().setSignature(executionSignature);
- return forest.getTree();
+ public static DecisionTree createBoostC45E_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ DecisionTreePruner pruner = new DecisionTreePruner();
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createBoostC45(wm, obj_class, new Entropy(), criteria ,pruner);
}
-// public static DecisionTree createSingleCrossPrunnedStopC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-//
-// ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
-// stopping_criteria.add(new EstimatedNodeSize(0.05));
-// return createSinglePrunnedC45(wm, obj_class, new Entropy(), new CrossValidation(10), stopping_criteria);
-// }
-// public static DecisionTree createSingleCrossPrunnedStopC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-// ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
-// stopping_criteria.add(new EstimatedNodeSize(0.05));
-// return createSinglePrunnedC45(wm, obj_class, new GainRatio(), new CrossValidation(10), stopping_criteria);
-// }
-//
-// public static DecisionTree createSingleTestPrunnedStopC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-//
-// ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
-// stopping_criteria.add(new EstimatedNodeSize(0.05));
-// return createSinglePrunnedC45(wm, obj_class, new Entropy(), new TestSample(0.2d), stopping_criteria);
-// }
-// public static DecisionTree createSingleTestPrunnedStopC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-// ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
-// stopping_criteria.add(new EstimatedNodeSize(0.05));
-// return createSinglePrunnedC45(wm, obj_class, new GainRatio(), new TestSample(0.2), stopping_criteria);
-// }
-//
-// public static DecisionTree createSingleCVPrunnedC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-//
-// ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
-// //stopping_criteria.add(new EstimatedNodeSize(0.05));
-// return createSinglePrunnedC45(wm, obj_class, new Entropy(), new CrossValidation(10), stopping_criteria);
-// }
-// public static DecisionTree createSingleCVPrunnedC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-// ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
-// //stopping_criteria.add(new EstimatedNodeSize(0.05));
-// return createSinglePrunnedC45(wm, obj_class, new GainRatio(), new CrossValidation(10), stopping_criteria);
-// }
-
- public static DecisionTree createSingleTestPrunnedC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-
-// ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
-// //stopping_criteria.add(new EstimatedNodeSize(0.05));
- return createSinglePrunnedC45(wm, obj_class, new Entropy(), new TestSample(0.2d));
+ public static DecisionTree createBoostC45G_PrunStop(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ DecisionTreePruner pruner = new DecisionTreePruner();
+ ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>(1);
+ return createBoostC45(wm, obj_class, new GainRatio(), criteria ,pruner);
}
- public static DecisionTree createSingleTestPrunnedC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-// ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
-// stopping_criteria.add(new EstimatedNodeSize(0.05));
- return createSinglePrunnedC45(wm, obj_class, new GainRatio(), new TestSample(0.2d));
- }
- protected static DecisionTree createSinglePrunnedC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h, ErrorEstimate validater) throws FeatureNotSupported {
+ protected static DecisionTree createBoostC45(WorkingMemory wm, Class<? extends Object> obj_class,
+ Heuristic h,
+ ArrayList<StoppingCriterion> criteria,
+ DecisionTreePruner pruner) throws FeatureNotSupported {
DataType data = Learner.DEFAULT_DATA;
-
- C45Learner learner = new C45Learner(h);
- learner.addStoppingCriteria(new EstimatedNodeSize(0.05));
- SingleTreeBuilder single_builder = new SingleTreeBuilder();
- single_builder.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
- single_builder.setTestRatio(Util.DEFAULT_TESTING_RATIO);
+ C45Learner learner = new C45Learner(h);
+ AdaBoostBuilder boosted_forest = new AdaBoostBuilder();
-// String algo_suffices = org.drools.learner.deprecated.DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
-// String executionSignature = org.drools.learner.deprecated.DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
- String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
+ String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), boosted_forest.getTreeAlgo());
String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
/* create the memory */
Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
- single_builder.build(mem, learner);//obj_class, target_attr, working_attr
-
-// validater.validate(learner, mem.getClassInstances());
-//
-// DecisionTreePruner pruner = new DecisionTreePruner(validater);
-//
-// DecisionTree dt = pruner.prun_to_estimate();
+ mem.setTrainRatio(Util.DEFAULT_TRAINING_RATIO);
+ mem.setTestRatio(Util.DEFAULT_TESTING_RATIO);
+ mem.processTestSet();
- // you should be able to get the pruned tree
- // prun.getMinimumCostTree()
- // prun.getOptimumCostTree()
+ for (StoppingCriterion sc: criteria) {
+ if (sc instanceof MaximumDepth) {
+ int max_depth = (int)((mem.getClassInstances().getSchema().getAttrNames().size() - 1)*0.70 );
+ ((MaximumDepth)sc).setDepth(max_depth);
+ }
+ learner.addStoppingCriteria(sc);
+ }
- single_builder.printResults(executionSignature);
- single_builder.printLatex(executionSignature);
- single_builder.printLatex2(executionSignature);
+ SolutionSet product = boosted_forest.build(mem, learner);//obj_class, target_attr, working_attr
+ StatsPrinter.printLatexComment("Builder errors", executionSignature, false);
+ StatsPrinter.printLatexComment(Stats.getErrors(), executionSignature, true);
+ StatsPrinter.printLatex(product.getGlobalTrainStats(), product.getGlobalTestStats(), executionSignature, true);
+ StatsPrinter.printLatexComment("Each Original Tree", executionSignature, true);
+ for (Solution s: product.getSolutions()) {
+ StatsPrinter.printLatex(s.getTrainStats(), s.getTestStats(), executionSignature, true);
+ }
+ Solution best_s = product.getBestSolution();
+ StatsPrinter.printLatexComment("Best Original Tree", executionSignature, true);
+ StatsPrinter.printLatex(best_s.getTrainStats(), best_s.getTestStats(), executionSignature, true);
+
+ Tester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
- SingleTreeTester.printStopping(learner.getStoppingCriteria(), Util.DRL_DIRECTORY +executionSignature);
- /* Once Talpha is found the tree that is finally suggested for use is that
- * which minimises the cost-complexity using and all the data use the pruner to prun the tree
- */
- //pruner.prun_tree(single_builder.getTree());
-
-
- // test the tree again
-
-
- //Tester.test(c45, mem.getClassInstances());
-
- single_builder.getTree().setSignature(executionSignature);
- return single_builder.getTree();
+ if (pruner != null) {
+ for (Solution sol: product.getSolutions())
+ pruner.prun_to_estimate(sol);
+
+ Solution s2 = pruner.getBestSolution();
+ Tester t2 = boosted_forest.getTester(pruner.getBestSolution().getTree());
+ StatsPrinter.printLatexComment("Best Pruned Tree", executionSignature, true);
+ StatsPrinter.printLatex(t2.test(s2.getList()), t2.test(s2.getTestList()), executionSignature, true);
+
+ }
+
+ boosted_forest.getBestSolution().getTree().setSignature(executionSignature);
+ return boosted_forest.getBestSolution().getTree();
}
+
+
public static String getSignature(Class<? extends Object> obj_class, String fileName, String suffices) {
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -7,18 +7,22 @@
import org.drools.learner.Memory;
import org.drools.learner.Stats;
import org.drools.learner.StatsPrinter;
+import org.drools.learner.builder.test.BoostedTester;
+import org.drools.learner.builder.test.ForestTester;
+import org.drools.learner.builder.test.SingleTreeTester;
+import org.drools.learner.builder.test.Tester;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
-public class ForestBuilder implements DecisionTreeBuilder{
+public class ForestBuilder extends DecisionTreeBuilder{
private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(ForestBuilder.class, SimpleLogger.DEFAULT_LEVEL);
private static SimpleLogger slog = LoggerFactory.getSysOutLogger(ForestBuilder.class, SimpleLogger.DEFAULT_LEVEL);
private TreeAlgo algorithm = TreeAlgo.BAG; // default bagging, TODO boosting
- private double trainRatio = Util.TRAINING_RATIO, testRatio = Util.TESTING_RATIO;
+// private double trainRatio = Util.TRAINING_RATIO, testRatio = Util.TESTING_RATIO;
private static final int FOREST_SIZE = Util.NUM_TREES;
private static final double TREE_SIZE_RATIO = 0.9;
@@ -26,61 +30,42 @@
private ArrayList<DecisionTree> forest;
//private Learner trainer;
+
- private DecisionTree best;
-
private DecisionTreeMerger merger;
- private ForestTester tester;
+// private ForestTester tester;
// private SingleTreeTester single_tester;
- private ArrayList<Stats> train_evaluation, test_evaluation;
+// private ArrayList<Stats> train_evaluation, test_evaluation;
public ForestBuilder() {
//this.trainer = _trainer;
merger = new DecisionTreeMerger();
- train_evaluation = new ArrayList<Stats>(FOREST_SIZE);
- test_evaluation = new ArrayList<Stats>(FOREST_SIZE);
+// train_evaluation = new ArrayList<Stats>(FOREST_SIZE);
+// test_evaluation = new ArrayList<Stats>(FOREST_SIZE);
}
- public void setTrainRatio(double ratio) {
- trainRatio = ratio;
- }
- public void setTestRatio(double ratio) {
- testRatio = ratio;
- }
- public void build(Memory mem, Learner _trainer) {
-
- final InstanceList class_instances = mem.getClassInstances();
- _trainer.setInputData(class_instances);
-
- if (class_instances.getTargets().size()>1) {
+ public void internalBuild(SolutionSet sol_set, Learner _trainer) {
+ _trainer.setInputSpec(sol_set.getInputSpec());
+ if (sol_set.getTargets().size()>1) {
//throw new FeatureNotSupported("There is more than 1 target candidates");
if (flog.error() !=null)
flog.error().log("There is more than 1 target candidates");
+
System.exit(0);
// TODO put the feature not supported exception || implement it
+ } else if (_trainer.getTargetDomain().getCategoryCount() >2) {
+ if (slog.error() !=null)
+ slog.error().log("The target domain is not binary!!!\n");
+ System.exit(0);
}
-
- int split_idx = (int)(trainRatio * class_instances.getSize());
- int split_idx2 = split_idx + (int)(testRatio * class_instances.getSize());
- InstanceList train_instances = class_instances.subList(0, split_idx);
- InstanceList test_instances = class_instances.subList(split_idx, split_idx2);//class_instances.getSize());
-
- int N = train_instances.getSize();
+ int N = sol_set.getTrainSet().getSize();
int tree_capacity = (int)(TREE_SIZE_RATIO * N);
_trainer.setTrainingDataSizePerTree(tree_capacity);
/* tree_capacity number of data fed to each tree, there are FOREST_SIZE trees*/
- _trainer.setTrainingDataSize(tree_capacity * FOREST_SIZE);
-
-// int N = class_instances.getSize();
-// // _trainer.setTrainingDataSize(N); => wrong
-// int tree_capacity = (int)(TREE_SIZE_RATIO * N);
-// _trainer.setTrainingDataSizePerTree(tree_capacity);
-//
-// /* tree_capacity number of data fed to each tree, there are FOREST_SIZE trees*/
-// _trainer.setTrainingDataSize(tree_capacity * FOREST_SIZE);
+ _trainer.setTrainingDataSize(tree_capacity * FOREST_SIZE);
forest = new ArrayList<DecisionTree> (FOREST_SIZE);
@@ -93,7 +78,7 @@
bag = Util.bag_wo_rep(tree_capacity, N);
//InstanceList working_instances = class_instances.getInstances(bag);
- InstanceList working_instances = train_instances.getInstances(bag);
+ InstanceList working_instances = sol_set.getTrainSet().getInstances(bag);
DecisionTree dt = _trainer.instantiate_tree();
if (slog.debug() != null)
@@ -105,33 +90,52 @@
forest.add(dt);
// the DecisionTreeMerger will visit the decision tree and add the paths that have not been seen yet to the list
merger.add(dt);
+
if (slog.stat() !=null)
slog.stat().stat(".");
- SingleTreeTester single_tester = new SingleTreeTester(dt);
+// SingleTreeTester single_tester = new SingleTreeTester(dt);
+// train_evaluation.add(single_tester.test(sol_set.getTrainSet()));
+// test_evaluation.add(single_tester.test(sol_set.getTestSet()));
- train_evaluation.add(single_tester.test(train_instances));
- test_evaluation.add(single_tester.test(test_instances));
+ // adding to the set of solutions
+ Tester t = getTester(dt);
+ Stats train = t.test(working_instances);
+ Stats test = t.test(sol_set.getTestSet());
+ Solution one = new Solution(dt, working_instances);
+ one.setTestList(sol_set.getTestSet());
+ one.setTrainStats(train);
+ one.setTestStats(test);
+ sol_set.addSolution(one);
}
- tester = new ForestTester(forest);
- train_evaluation.add(tester.test(train_instances));
- test_evaluation.add(tester.test(test_instances));
+// tester = new ForestTester(forest);
+// train_evaluation.add(tester.test(sol_set.getTrainSet()));
+// test_evaluation.add(tester.test(sol_set.getTestSet()));
+ Tester global_tester = getTester(forest);
+ Stats train = global_tester.test(sol_set.getTrainSet());
+ Stats test = global_tester.test(sol_set.getTestSet());
+ sol_set.setGlobalTrainStats(train);
+ sol_set.setGlobalTestStats(test);
+
//System.exit(0);
// TODO how to compute a best tree from the forest
- int best_id = getMinTestId();
- best = merger.getBest();
- if (best == null)
- best = forest.get(best_id);
+ int best_id = sol_set.getMinTestId();
+ sol_set.setBestSolutionId(best_id);
- train_evaluation.add(train_evaluation.get(best_id));
- test_evaluation.add(test_evaluation.get(best_id));
- //_trainer.setBestTree(best);// forest.get(0));
- //this.c45 = dt;
+
+ return;
+
}
+
+
+ public Tester getTester(ArrayList<DecisionTree> _forest) {
+ return new ForestTester(_forest);
+ }
+
public TreeAlgo getTreeAlgo() {
return algorithm; //TreeAlgo.BAG; // default
@@ -141,74 +145,10 @@
return forest;
}
- public DecisionTree getTree() {
- return best;
+ public Solution getBestSolution() {
+ return solutions.getBestSolution();
}
-
- public int getMinTestId() {
- double min = 1.0;
- int id = -1;
- for (int i=0; i< FOREST_SIZE; i++ ) {
- Stats test_s=test_evaluation.get(i);
- Stats train_s=train_evaluation.get(i);
- double test_error = ((double)test_s.getResult(Stats.INCORRECT)/(double)test_s.getTotal());
- double train_error = ((double)train_s.getResult(Stats.INCORRECT)/(double)train_s.getTotal());
- if (test_error < min) {
- min = test_error;
- id = i;
- } else if (test_error == min) {
- Stats old = train_evaluation.get(id);
- double train_old = ((double)old.getResult(Stats.INCORRECT)/(double)old.getTotal());
- if (train_error < train_old) {
- min = test_error;
- id = i;
- }
- }
-
- }
- return id;
-
- }
-
- public void printResults(String executionSignature) {
- StringBuffer sb = new StringBuffer();
- sb.append("#"+ Stats.getErrors());
- for (int i =0 ; i<FOREST_SIZE; i++) {
- sb.append(train_evaluation.get(i).print2string() + "\n");
- }
- sb.append( "\n\n");
- for (int i =0 ; i<FOREST_SIZE; i++) {
- sb.append(test_evaluation.get(i).print2string() + "\n");
- }
- sb.append( "\n");
- sb.append(train_evaluation.get(FOREST_SIZE).print2string() + "\n");
- sb.append( "\n");
- sb.append(test_evaluation.get(FOREST_SIZE).print2string() + "\n");
-
- sb.append( "\n");
- sb.append(train_evaluation.get(FOREST_SIZE+1).print2string() + "\n");
- sb.append( "\n");
- sb.append(test_evaluation.get(FOREST_SIZE+1).print2string() + "\n");
-
- StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, false);
- }
-
- public void printLatex(String executionSignature) {
- StringBuffer sb = new StringBuffer();
- sb.append("#"+ Stats.getErrors());
- for (int i =0 ; i<FOREST_SIZE; i++) {
- sb.append(train_evaluation.get(i).print4Latex() +test_evaluation.get(i).print4Latex() + "\\\\"+"\n");
- }
- sb.append( "\n");
- sb.append(train_evaluation.get(FOREST_SIZE).print4Latex() +test_evaluation.get(FOREST_SIZE).print4Latex() +"\\\\"+"\n");
-
- sb.append( "\n");
- sb.append(train_evaluation.get(FOREST_SIZE+1).print4Latex() +test_evaluation.get(FOREST_SIZE+1).print4Latex()+"\\\\"+ "\n");
- sb.append( "\n");
-
- StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, true);
- }
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -71,6 +71,14 @@
TreeNode currentNode = new TreeNode(node_domain);
currentNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
+ currentNode.setRank((double)data_stats.getSum()/
+ (double)this.getTrainingDataSize() /* total size of data fed to trainer*/);
+ //currentNode.setInfoMea(best_attr_eval.attribute_eval);
+ //what the highest represented class is and what proportion of items at that node actually are that class
+ currentNode.setLabel(data_stats.get_winner_class());
+ currentNode.setNumLabeled(data_stats.getSupportersFor(data_stats.get_winner_class()).size());
+
+
Hashtable<Object, InstDistribution> filtered_stats = data_stats.splitFromCategorical(node_domain, null);
dt.FACTS_READ += data_stats.getSum();
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -44,6 +44,7 @@
this.data_size_per_tree = 0;
criteria = new ArrayList<StoppingCriterion>(4);
+ missclassified_data = new HashSet<Instance>();
}
@@ -123,8 +124,6 @@
public void setTrainingDataSizePerTree(int num) {
this.data_size_per_tree = num;
-
- missclassified_data = new HashSet<Instance>();
}
public int getTrainingDataSizePerTree() {
@@ -150,13 +149,13 @@
this.algorithm = type;
}
- public void setInputData(InstanceList class_instances) {
+ public void setInputSpec(InstanceList class_instances) {
this.input_data = class_instances;
}
- public InstanceList getInputData() {
- return input_data;
- }
+// public InstanceList getInputData() {
+// return input_data;
+// }
// must be deleted, goes to builder
// public void setBestTree(DecisionTree dt) {
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -1,58 +1,40 @@
package org.drools.learner.builder;
-import java.util.ArrayList;
import org.drools.learner.DecisionTree;
-import org.drools.learner.DecisionTreePruner;
-import org.drools.learner.InstanceList;
-import org.drools.learner.Memory;
import org.drools.learner.Stats;
import org.drools.learner.StatsPrinter;
+import org.drools.learner.builder.test.SingleTreeTester;
+import org.drools.learner.builder.test.Tester;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
-public class SingleTreeBuilder implements DecisionTreeBuilder{
+public class SingleTreeBuilder extends DecisionTreeBuilder{
private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(SingleTreeBuilder.class, SimpleLogger.DEFAULT_LEVEL);
private static SimpleLogger slog = LoggerFactory.getSysOutLogger(SingleTreeBuilder.class, SimpleLogger.DEFAULT_LEVEL);
- private boolean prune = false;
+
private TreeAlgo algorithm = TreeAlgo.SINGLE; // default bagging, TODO boosting
- private DecisionTree one_tree;
-
- private Stats train_evaluation, test_evaluation;
- private Stats train_evaluation2, test_evaluation2;
- private double trainRatio = Util.TRAINING_RATIO;
- private double testRatio = Util.TESTING_RATIO;
+// private Stats train_evaluation, test_evaluation;
+// private Stats train_evaluation2, test_evaluation2;
+// private SingleTreeTester tester;
- private SingleTreeTester tester;
-
public SingleTreeBuilder() {//Learner _trainer) {
- //this.trainer = _trainer;
- //dom_type = trainer.getDomainType();
}
- public void setTrainRatio(double ratio) {
- trainRatio = ratio;
- }
-
- public void setTestRatio(double ratio) {
- testRatio = ratio;
- }
-
/*
* the memory has the information
* the instances: the objects which the decision tree will work on
* the schema: the definition of the object instance
* (Class<?>) klass, String targetField, List<String> workingAttributes
*/
- public void build(Memory mem, Learner _trainer) {
- final InstanceList class_instances = mem.getClassInstances();
- _trainer.setInputData(class_instances);
- if (class_instances.getTargets().size()>1) {
+ public void internalBuild(SolutionSet sol, Learner _trainer) {
+
+ if (sol.getTargets().size()>1) {
//throw new FeatureNotSupported("There is more than 1 target candidates");
if (flog.error() !=null)
flog.error().log("There is more than 1 target candidates");
@@ -60,87 +42,60 @@
System.exit(0);
// TODO put the feature not supported exception || implement it
}
- int split_idx = (int)(trainRatio * class_instances.getSize());
- int split_idx2 = split_idx + (int)(testRatio * class_instances.getSize());
- InstanceList train_instances = class_instances.subList(0, split_idx);
- InstanceList test_instances = class_instances.subList(split_idx, split_idx2);//class_instances.getSize());
-
- _trainer.setTrainingDataSize(train_instances.getSize());
- _trainer.setTrainingDataSizePerTree(train_instances.getSize());
-
- one_tree = _trainer.instantiate_tree();
+ _trainer.setInputSpec(sol.getInputSpec());
+ _trainer.setTrainingDataSize(sol.getTrainSet().getSize());
+ DecisionTree one_tree = _trainer.instantiate_tree();
if (slog.debug() != null)
slog.debug().log("\n"+"Training a tree"+"\n");
- _trainer.train_tree(one_tree, train_instances);
- one_tree.setTrainingDataSize(train_instances.getSize());
- one_tree.setTestingDataSize(test_instances.getSize());
- one_tree.setTrain(train_instances);
- one_tree.setTest(test_instances);
-
- tester = new SingleTreeTester(one_tree);
+ _trainer.train_tree(one_tree, sol.getTrainSet());
+ one_tree.setID(0);
- train_evaluation = tester.test(train_instances);
- test_evaluation = tester.test(test_instances);
- one_tree.setValidationError(Util.division(test_evaluation.getResult(Stats.INCORRECT), test_instances.getSize()));
- one_tree.setTrainingError(Util.division(train_evaluation.getResult(Stats.INCORRECT), train_instances.getSize()));
+ Tester t = getTester(one_tree);
+ Stats train = t.test(sol.getTrainSet());
+ Stats test = t.test(sol.getTestSet());
+ Solution best = new Solution(one_tree, sol.getTrainSet());
+ best.setTestList(sol.getTestSet());
+ best.setTrainStats(train);
+ best.setTestStats(test);
+ sol.addSolution(best);
- //if (prunner != null) {
- if (prune) {
- DecisionTreePruner pruner = new DecisionTreePruner();
- ArrayList<DecisionTree> dts = new ArrayList<DecisionTree>(1);
- dts.add(one_tree);
- pruner.prun_to_estimate(dts);
-// SingleTreeTester tester2 = new SingleTreeTester(one_tree);
-
- train_evaluation2 = tester.test(train_instances);
- test_evaluation2 = tester.test(test_instances);
-
- }
+ return;
-
- // must be deleted, goes to builder
- //_trainer.setBestTree(one_tree);
}
- public Stats getTrainStats() {
- return train_evaluation;
- }
- public Stats getTestStats() {
- return test_evaluation;
- }
-
public TreeAlgo getTreeAlgo() {
return this.algorithm; // default
}
- public DecisionTree getTree() {
- return one_tree;
+ public Solution getBestSolution() {
+ return solutions.getBestSolution();
}
- public void printResults(String executionSignature) {
- tester.printStats(getTrainStats(), Util.DRL_DIRECTORY + executionSignature, false);
- tester.printStats(getTestStats(), Util.DRL_DIRECTORY + executionSignature, true);
- }
-
- public void printLatex(String executionSignature) {
- StringBuffer sb = new StringBuffer();
- sb.append("#"+ Stats.getErrors());
- sb.append( "\n");
- sb.append(getTrainStats().print4Latex() +getTestStats().print4Latex()+"\\\\"+ "\n");
- sb.append( "\n");
-
- StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, true);
- }
- public void printLatex2(String executionSignature) {
- StringBuffer sb = new StringBuffer();
- sb.append("#"+ Stats.getErrors());
- sb.append( "\n");
- sb.append(train_evaluation2.print4Latex() +test_evaluation2.print4Latex()+"\\\\"+ "\n");
- sb.append( "\n");
-
- StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, true);
- }
+// public void printResults(String executionSignature) {
+// tester.printStats(getTrainStats(), Util.DRL_DIRECTORY + executionSignature, false);
+// tester.printStats(getTestStats(), Util.DRL_DIRECTORY + executionSignature, true);
+// }
+//
+// public void printLatex(String executionSignature) {
+// StringBuffer sb = new StringBuffer();
+// sb.append("#"+ Stats.getErrors());
+// sb.append( "\n");
+// sb.append(getTrainStats().print4Latex() +getTestStats().print4Latex()+"\\\\"+ "\n");
+// sb.append( "\n");
+//
+// StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, true);
+// }
+//
+// public void printLatex2(String executionSignature) {
+// StringBuffer sb = new StringBuffer();
+// sb.append("#"+ Stats.getErrors());
+// sb.append( "\n");
+// sb.append(train_evaluation2.print4Latex() +test_evaluation2.print4Latex()+"\\\\"+ "\n");
+// sb.append( "\n");
+//
+// StatsPrinter.print2file(sb, Util.DRL_DIRECTORY +executionSignature, true);
+// }
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -7,7 +7,7 @@
import org.drools.learner.InstanceList;
import org.drools.learner.Stats;
import org.drools.learner.builder.Learner;
-import org.drools.learner.builder.SingleTreeTester;
+import org.drools.learner.builder.test.SingleTreeTester;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
@@ -103,15 +103,15 @@
error ++;
}
}
- dt.setValidationError(Util.division(error, fold_size));
+ //TODO dt.setTrainError(Util.division(error, fold_size));
dt.calc_num_node_leaves(dt.getRoot());
- if (slog.error() !=null)
- slog.error().log("The estimate of : "+(i-1)+" training=" +dt.getTrainingError() +" valid=" + dt.getValidationError() +" num_leaves=" + dt.getRoot().getNumLeaves()+"\n");
+// if (slog.error() !=null)
+// slog.error().log("The estimate of : "+(i-1)+" training=" +dt.getTrainingError() +" valid=" + dt.getValidationError() +" num_leaves=" + dt.getRoot().getNumLeaves()+"\n");
/* moving averages */
validation_error_estimate += ((double)error/(double) fold_size)/(double)k_fold;
- training_error_estimate += ((double)dt.getTrainingError())/(double)k_fold;//((double)dt.getTrainingError()/(double)(num_instances-fold_size))/(double)k_fold;
+ //TODO training_error_estimate += ((double)dt.getTrainingError())/(double)k_fold;//((double)dt.getTrainingError()/(double)(num_instances-fold_size))/(double)k_fold;
num_leaves_estimate += (double)dt.getRoot().getNumLeaves()/(double)k_fold;
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -168,13 +168,13 @@
if ((num_supp > 0) && !winner.equals(looser)) {
- //System.out.println(Util.ntimes("DANIEL", 2)+ " one looser ? "+looser + " num of sup="+num_supp);
- //System.out.println(" the num of supporters = "+ stats.getVoteFor(looser));
- //System.out.println(" but the guys "+ stats.getSupportersFor(looser));
- //System.out.println("How many bok: "+stats.getSupportersFor(looser).size());
+// System.out.println(Util.ntimes("DANIEL", 2)+ " one looser ? "+looser + " num of sup="+num_supp);
+// System.out.println(" the num of supporters = "+ this.getVoteFor(looser));
+// System.out.println(" but the guys "+ this.getSupportersFor(looser));
+// System.out.println("How many bok: "+this.getSupportersFor(looser).size());
missclassification.addAll(getSupportersFor(looser));
} else {
- //System.out.println(Util.ntimes("DANIEL", 5)+ "how many times matching?? not a looser "+ looser );
+// System.out.println(Util.ntimes("DANIEL", 5)+ "how many times matching?? not a looser "+ looser );
}
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -6,7 +6,7 @@
import org.drools.learner.InstanceList;
import org.drools.learner.Stats;
import org.drools.learner.builder.Learner;
-import org.drools.learner.builder.SingleTreeTester;
+import org.drools.learner.builder.test.SingleTreeTester;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
@@ -79,15 +79,15 @@
error ++;
}
}
- dt.setValidationError(Util.division(error, test_set.getSize()));
+ //TODO dt.setValidationError(Util.division(error, test_set.getSize()));
dt.calc_num_node_leaves(dt.getRoot());
- if (slog.error() !=null)
- slog.error().log("The estimate of : "+(0)+" training=" +dt.getTrainingError() +" valid=" + dt.getValidationError() +" num_leaves=" + dt.getRoot().getNumLeaves()+"\n");
+// if (slog.error() !=null)
+// slog.error().log("The estimate of : "+(0)+" training=" +dt.getTrainingError() +" valid=" + dt.getValidationError() +" num_leaves=" + dt.getRoot().getNumLeaves()+"\n");
/* moving averages */
- error_estimate = dt.getValidationError();
- training_error_estimate = (double)dt.getTrainingError();
+ //TODO error_estimate = dt.getValidationError();
+ //TODO training_error_estimate = (double)dt.getTrainingError();
num_leaves_estimate = (double)dt.getRoot().getNumLeaves();
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -22,5 +22,8 @@
public int getNumPruned() {
return num_prunned;
}
+ public void setDepth(int max_depth) {
+ limit_depth = max_depth;
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/test/java/org/drools/learner/StructuredTestFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/test/java/org/drools/learner/StructuredTestFactory.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/test/java/org/drools/learner/StructuredTestFactory.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -4,8 +4,8 @@
import org.drools.learner.builder.C45Learner;
import org.drools.learner.builder.Learner;
import org.drools.learner.builder.SingleTreeBuilder;
-import org.drools.learner.builder.SingleTreeTester;
import org.drools.learner.builder.Learner.DataType;
+import org.drools.learner.builder.test.SingleTreeTester;
import org.drools.learner.builder.DecisionTreeFactory;
import org.drools.learner.eval.Entropy;
import org.drools.learner.tools.FeatureNotSupported;
@@ -30,12 +30,12 @@
if (BUILD_TREE) {
single_builder.build(mem, learner);//obj_class, target_attr, working_attr
- SingleTreeTester tester = new SingleTreeTester(single_builder.getTree());
+ SingleTreeTester tester = new SingleTreeTester(single_builder.getBestSolution().getTree());
//tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature);
//Tester.test(c45, mem.getClassInstances());
- single_builder.getTree().setSignature(executionSignature);
+ single_builder.getBestSolution().getTree().setSignature(executionSignature);
}
- return single_builder.getTree();
+ return single_builder.getBestSolution().getTree();
}
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -65,34 +65,17 @@
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
break;
case 321:
- decision_tree = DecisionTreeFactory.createBoostedC45E(session, obj_class);
+ decision_tree = DecisionTreeFactory.createBoostC45E(session, obj_class);
break;
case 322:
- decision_tree = DecisionTreeFactory.createBoostedC45G(session, obj_class);
+ decision_tree = DecisionTreeFactory.createBoostC45G(session, obj_class);
break;
-// case 3:
-// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
-// break;
-// case 400:
-// decision_tree = DecisionTreeFactory.createSingleCVPrunnedC45E(session, obj_class);
-// break;
-// case 500:
-// decision_tree = DecisionTreeFactory.createSingleC45E_Stopped(session, obj_class);
-// break;
-// case 600:
-// decision_tree = DecisionTreeFactory.createSingleCrossPrunnedStopC45E(session, obj_class);
-// break;
-// case 601:
-// decision_tree = DecisionTreeFactory.createSingleTestPrunnedStopC45E(session, obj_class);
-// break;
- case 700:
- decision_tree = DecisionTreeFactory.createSingleC45E_StoppedTest(session, obj_class);
- break;
+
case 701:
- decision_tree = DecisionTreeFactory.createBaggC45E_StoppedTest(session, obj_class);
+ decision_tree = DecisionTreeFactory.createBagC45E_Stop(session, obj_class);
break;
case 702:
- decision_tree = DecisionTreeFactory.createBoostedC45E_StopTest(session, obj_class);
+ decision_tree = DecisionTreeFactory.createBoostC45E_Stop(session, obj_class);
break;
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -67,24 +67,7 @@
case 222:
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
break;
- case 321:
- decision_tree = DecisionTreeFactory.createBoostedC45E(session, obj_class);
- break;
- case 322:
- decision_tree = DecisionTreeFactory.createBoostedC45G(session, obj_class);
- break;
-// case 3:
-// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
-// break;
-// case 400:
-// // decision_tree = DecisionTreeFactory.createSingleCVPrunnedC45E(session, obj_class);
-// break;
- case 500:
- decision_tree = DecisionTreeFactory.createSingleC45E_Stopped(session, obj_class);
- break;
-// case 600:
-// // decision_tree = DecisionTreeFactory.createSingleCrossPrunnedStopC45E(session, obj_class);
-// break;
+
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/MannersLearnerBenchmark.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/MannersLearnerBenchmark.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/MannersLearnerBenchmark.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -71,18 +71,7 @@
case 221:
decision_tree = DecisionTreeFactory.createBagC45E(session, obj_class);
break;
- case 222:
- decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
- break;
- case 321:
- decision_tree = DecisionTreeFactory.createBoostedC45E(session, obj_class);
- break;
- case 322:
- decision_tree = DecisionTreeFactory.createBoostedC45G(session, obj_class);
- break;
-// case 3:
-// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
-// break;
+
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -64,15 +64,7 @@
case 222:
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
break;
- case 321:
- decision_tree = DecisionTreeFactory.createBoostedC45E(session, obj_class);
- break;
- case 322:
- decision_tree = DecisionTreeFactory.createBoostedC45G(session, obj_class);
- break;
-// case 3:
-// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
-// break;
+
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -48,7 +48,7 @@
// }
// instantiate a learner for a specific object class and pass session to train
- DecisionTree decision_tree; int ALGO = 700;
+ DecisionTree decision_tree; int ALGO = 141;
/*
* Single 1xx, Bag 2xx, Boost 3xx
* ID3 x1x, C45 x2x
@@ -67,45 +67,54 @@
case 122:
decision_tree = DecisionTreeFactory.createSingleC45G(session, obj_class);
break;
+ case 131:
+ decision_tree = DecisionTreeFactory.createSingleC45E_Stop(session, obj_class);
+ break;
+ case 132:
+ decision_tree = DecisionTreeFactory.createSingleC45G_Stop(session, obj_class);
+ break;
+ case 141:
+ decision_tree = DecisionTreeFactory.createSingleC45E_PrunStop(session, obj_class);
+ break;
+ case 142:
+ decision_tree = DecisionTreeFactory.createSingleC45G_PrunStop(session, obj_class);
+ break;
case 221:
decision_tree = DecisionTreeFactory.createBagC45E(session, obj_class);
break;
case 222:
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
break;
+ case 231:
+ decision_tree = DecisionTreeFactory.createBagC45E_Stop(session, obj_class);
+ break;
+ case 232:
+ decision_tree = DecisionTreeFactory.createBagC45G_Stop(session, obj_class);
+ break;
+ case 241:
+ decision_tree = DecisionTreeFactory.createBagC45E_PrunStop(session, obj_class);
+ break;
+ case 242:
+ decision_tree = DecisionTreeFactory.createBagC45G_PrunStop(session, obj_class);
+ break;
case 321:
- decision_tree = DecisionTreeFactory.createBoostedC45E(session, obj_class);
- break;
+ decision_tree = DecisionTreeFactory.createBoostC45E(session, obj_class);
+ break;
case 322:
- decision_tree = DecisionTreeFactory.createBoostedC45G(session, obj_class);
+ decision_tree = DecisionTreeFactory.createBoostC45G(session, obj_class);
+ break;
+ case 331:
+ decision_tree = DecisionTreeFactory.createBoostC45E_Stop(session, obj_class);
+ break;
+ case 332:
+ decision_tree = DecisionTreeFactory.createBoostC45G_Stop(session, obj_class);
+ break;
+ case 341:
+ decision_tree = DecisionTreeFactory.createBoostC45E_PrunStop(session, obj_class);
+ break;
+ case 342:
+ decision_tree = DecisionTreeFactory.createBoostC45G_PrunStop(session, obj_class);
break;
-// case 3:
-// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
-// break;
- case 400:
- decision_tree = DecisionTreeFactory.createSingleC45E_Stopped(session, obj_class);
- break;
- case 600:
- decision_tree = DecisionTreeFactory.createSingleTestPrunnedC45E(session, obj_class);
- break;
- case 700:
- decision_tree = DecisionTreeFactory.createSingleC45E_Test(session, obj_class);
- break;
- case 701:
- decision_tree = DecisionTreeFactory.createBaggC45E_Test(session, obj_class);
- break;
- case 702:
- decision_tree = DecisionTreeFactory.createBoostedC45E_Test(session, obj_class);
- break;
- case 710:
- decision_tree = DecisionTreeFactory.createSingleC45E_StoppedTest(session, obj_class);
- break;
- case 711:
- decision_tree = DecisionTreeFactory.createBaggC45E_StoppedTest(session, obj_class);
- break;
- case 712:
- decision_tree = DecisionTreeFactory.createBoostedC45E_StopTest(session, obj_class);
- break;
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
@@ -120,7 +129,7 @@
*/
ruleBase.addPackage( builder.getPackage() );
- session.fireAllRules();
+ //session.fireAllRules();
long end_time = System.currentTimeMillis();
System.out.println("Total time="+ (end_time-start_time));
ReteStatistics stats = new ReteStatistics(ruleBase);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/RestaurantExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/RestaurantExample.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/RestaurantExample.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -65,15 +65,7 @@
case 222:
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
break;
- case 321:
- decision_tree = DecisionTreeFactory.createBoostedC45E(session, obj_class);
- break;
- case 322:
- decision_tree = DecisionTreeFactory.createBoostedC45G(session, obj_class);
- break;
-// case 3:
-// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
-// break;
+
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/ShoppingExm.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/ShoppingExm.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/ShoppingExm.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -46,9 +46,7 @@
case 2:
decision_tree = DecisionTreeFactory.createBagC45E(session, obj_class);
break;
- case 3:
- decision_tree = DecisionTreeFactory.createBoostedC45G(session, obj_class);
- break;
+
default:
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredCarExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredCarExample.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredCarExample.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -66,27 +66,7 @@
case 222:
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
break;
- case 321:
- decision_tree = DecisionTreeFactory.createBoostedC45E(session, obj_class);
- break;
- case 322:
- decision_tree = DecisionTreeFactory.createBoostedC45G(session, obj_class);
- break;
-// case 3:
-// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
-// break;
-// case 400:
-// decision_tree = DecisionTreeFactory.createSingleCVPrunnedC45E(session, obj_class);
-// break;
-// case 500:
-// decision_tree = DecisionTreeFactory.createSingleC45E_Stopped(session, obj_class);
-// break;
-// case 600:
-// decision_tree = DecisionTreeFactory.createSingleCrossPrunnedStopC45E(session, obj_class);
-// break;
-// case 601:
-// decision_tree = DecisionTreeFactory.createSingleTestPrunnedStopC45E(session, obj_class);
-// break;
+
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredNurseryExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredNurseryExample.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/StructuredNurseryExample.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -66,27 +66,7 @@
case 222:
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
break;
- case 321:
- decision_tree = DecisionTreeFactory.createBoostedC45E(session, obj_class);
- break;
- case 322:
- decision_tree = DecisionTreeFactory.createBoostedC45G(session, obj_class);
- break;
-// case 3:
-// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
-// break;
-// case 400:
-// decision_tree = DecisionTreeFactory.createSingleCVPrunnedC45E(session, obj_class);
-// break;
-// case 500:
-// decision_tree = DecisionTreeFactory.createSingleC45E_Stopped(session, obj_class);
-// break;
-// case 600:
-// decision_tree = DecisionTreeFactory.createSingleCrossPrunnedStopC45E(session, obj_class);
-// break;
-// case 601:
-// decision_tree = DecisionTreeFactory.createSingleTestPrunnedStopC45E(session, obj_class);
-// break;
+
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java 2008-08-22 01:05:47 UTC (rev 21669)
@@ -63,30 +63,7 @@
case 222:
decision_tree = DecisionTreeFactory.createBagC45G(session, obj_class);
break;
- case 321:
- decision_tree = DecisionTreeFactory.createBoostedC45E(session, obj_class);
- break;
- case 322:
- decision_tree = DecisionTreeFactory.createBoostedC45G(session, obj_class);
- break;
- case 500:
- decision_tree = DecisionTreeFactory.createSingleC45E_Stopped(session, obj_class);
- break;
-// case 600:
-// decision_tree = DecisionTreeFactory.createSingleCrossPrunnedStopC45E(session, obj_class);
-// break;
-// case 3:
-// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
-// break;
- case 700:
- decision_tree = DecisionTreeFactory.createSingleC45E_StoppedTest(session, obj_class);
- break;
- case 701:
- decision_tree = DecisionTreeFactory.createBaggC45E_StoppedTest(session, obj_class);
- break;
- case 702:
- decision_tree = DecisionTreeFactory.createBoostedC45E_StopTest(session, obj_class);
- break;
+
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/golf_c45_one.drl
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/golf_c45_one.drl 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/golf_c45_one.drl 2008-08-22 01:05:47 UTC (rev 21669)
@@ -2,40 +2,40 @@
import org.drools.examples.learner.Golf
-rule "#0 decision= Play classifying 4.0 num of facts with rank:0.2857142857142857"
+rule "#1 decision= Play classifying 4.0 num of facts with rank:Infinity"
when
$golf_0 : Golf(outlook == "overcast", $target_label : decision )
then
System.out.println("[decision] Expected value (" + $target_label + "), Classified as (Play )");
end
-rule "#3 decision= Play classifying 3.0 num of facts with rank:0.21428571428571427"
+rule "#2 decision= Play classifying 1.0 num of facts with rank:Infinity"
when
- $golf_0 : Golf(outlook == "rain", windy == false, $target_label : decision )
+ $golf_0 : Golf(outlook == "sunny", humidity <= 80, $target_label : decision )
then
System.out.println("[decision] Expected value (" + $target_label + "), Classified as (Play )");
end
-rule "#4 decision= Don't Play classifying 3.0 num of facts with rank:0.21428571428571427"
+rule "#3 decision= Don't Play classifying 2.0 num of facts with rank:Infinity"
when
- $golf_0 : Golf(outlook == "sunny", humidity > 77, $target_label : decision )
+ $golf_0 : Golf(outlook == "rain", windy == true, $target_label : decision )
then
System.out.println("[decision] Expected value (" + $target_label + "), Classified as (Don't Play )");
end
-rule "#1 decision= Play classifying 2.0 num of facts with rank:0.14285714285714285"
+rule "#4 decision= Play classifying 2.0 num of facts with rank:Infinity"
when
- $golf_0 : Golf(outlook == "sunny", humidity <= 77, $target_label : decision )
+ $golf_0 : Golf(outlook == "rain", windy == false, $target_label : decision )
then
System.out.println("[decision] Expected value (" + $target_label + "), Classified as (Play )");
end
-rule "#2 decision= Don't Play classifying 2.0 num of facts with rank:0.14285714285714285"
+rule "#5 decision= Don't Play classifying 2.0 num of facts with rank:Infinity"
when
- $golf_0 : Golf(outlook == "rain", windy == true, $target_label : decision )
+ $golf_0 : Golf(outlook == "sunny", humidity > 80, $target_label : decision )
then
System.out.println("[decision] Expected value (" + $target_label + "), Classified as (Don't Play )");
end
-//THE END: Total number of facts correctly classified= 14 over 14.0
+//THE END: Total number of facts correctly classified= 11 over 11.0
//with 5 number of rules over 5 total number of rules
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/golf_c45_one.stats
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/golf_c45_one.stats 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/golf_c45_one.stats 2008-08-22 01:05:47 UTC (rev 21669)
@@ -1,4 +1,13 @@
-TESTING results: incorrect 0
-TESTING results: correct 14
-TESTING results: unknown 0
-TESTING results: Total Number 14
\ No newline at end of file
+#INCORRECT CORRECT TOTAL
+
+ & 0 & 0 & 11 & 100 & 11 & 100 & 0 & 0 & 3 & 100 & 3 & 100\\
+
+
+#INCORRECT CORRECT TOTAL
+
+ & 0 & 0 & 11 & 100 & 11 & 100 & 0 & 0 & 3 & 100 & 3 & 100\\
+
+
+ & OBJECT_TYPE_NODE & ALPHA_NODE & BETA_NODE & TERMINAL_NODE
+ & 1 & 7 & 0 & 5\\
+
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/restaurant_id3_one.drl
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/restaurant_id3_one.drl 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/restaurant_id3_one.drl 2008-08-22 01:05:47 UTC (rev 21669)
@@ -2,54 +2,54 @@
import org.drools.examples.learner.Restaurant
-rule "#7 will_wait= true classifying 3.0 num of facts with rank:0.2727272727272727"
+rule "#2 will_wait= false classifying 2.0 num of facts with rank:Infinity"
when
- $restaurant_0 : Restaurant(patrons == "Some", $target_label : will_wait )
- then
- System.out.println("[will_wait] Expected value (" + $target_label + "), Classified as (true )");
-end
-
-rule "#1 will_wait= false classifying 2.0 num of facts with rank:0.18181818181818182"
- when
$restaurant_0 : Restaurant(patrons == "None", $target_label : will_wait )
then
System.out.println("[will_wait] Expected value (" + $target_label + "), Classified as (false )");
end
-rule "#4 will_wait= false classifying 2.0 num of facts with rank:0.18181818181818182"
+rule "#3 will_wait= false classifying 1.0 num of facts with rank:Infinity"
when
- $restaurant_0 : Restaurant(patrons == "Full", hungry == false, $target_label : will_wait )
+ $restaurant_0 : Restaurant(patrons == "Full", hungry == true, type == "Italian", $target_label : will_wait )
then
System.out.println("[will_wait] Expected value (" + $target_label + "), Classified as (false )");
end
-rule "#2 will_wait= false classifying 1.0 num of facts with rank:0.09090909090909091"
+rule "#4 will_wait= false classifying 1.0 num of facts with rank:Infinity"
when
- $restaurant_0 : Restaurant(patrons == "Full", hungry == true, type == "Italian", $target_label : will_wait )
+ $restaurant_0 : Restaurant(patrons == "Full", hungry == true, type == "Thai", fri_sat == false, $target_label : will_wait )
then
System.out.println("[will_wait] Expected value (" + $target_label + "), Classified as (false )");
end
-rule "#3 will_wait= false classifying 1.0 num of facts with rank:0.09090909090909091"
+rule "#5 will_wait= false classifying 2.0 num of facts with rank:Infinity"
when
- $restaurant_0 : Restaurant(patrons == "Full", hungry == true, type == "Thai", fri_sat == false, $target_label : will_wait )
+ $restaurant_0 : Restaurant(patrons == "Full", hungry == false, $target_label : will_wait )
then
System.out.println("[will_wait] Expected value (" + $target_label + "), Classified as (false )");
end
-rule "#5 will_wait= true classifying 1.0 num of facts with rank:0.09090909090909091"
+rule "#6 will_wait= true classifying 1.0 num of facts with rank:Infinity"
when
$restaurant_0 : Restaurant(patrons == "Full", hungry == true, type == "Thai", fri_sat == true, $target_label : will_wait )
then
System.out.println("[will_wait] Expected value (" + $target_label + "), Classified as (true )");
end
-rule "#6 will_wait= true classifying 1.0 num of facts with rank:0.09090909090909091"
+rule "#7 will_wait= true classifying 1.0 num of facts with rank:Infinity"
when
$restaurant_0 : Restaurant(patrons == "Full", hungry == true, type == "Burger", $target_label : will_wait )
then
System.out.println("[will_wait] Expected value (" + $target_label + "), Classified as (true )");
end
-//THE END: Total number of facts correctly classified= 11 over 11.0
+rule "#8 will_wait= true classifying 1.0 num of facts with rank:Infinity"
+ when
+ $restaurant_0 : Restaurant(patrons == "Some", $target_label : will_wait )
+ then
+ System.out.println("[will_wait] Expected value (" + $target_label + "), Classified as (true )");
+end
+
+//THE END: Total number of facts correctly classified= 9 over 9.0
//with 7 number of rules over 8 total number of rules
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/restaurant_id3_one.stats
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/restaurant_id3_one.stats 2008-08-21 19:48:02 UTC (rev 21668)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/restaurant_id3_one.stats 2008-08-22 01:05:47 UTC (rev 21669)
@@ -1,4 +1,5 @@
-TESTING results: incorrect 0
-TESTING results: correct 11
-TESTING results: unknown 0
-TESTING results: Total Number 11
\ No newline at end of file
+#INCORRECT CORRECT TOTAL
+
+ & 0 & 0 & 9 & 100 & 9 & 100 & 1 & 50 & 1 & 50 & 2 & 100\\
+
+
More information about the jboss-svn-commits mailing list