[jboss-svn-commits] JBL Code SVN: r21142 - in labs/jbossrules/contrib/machinelearning/5.0: drools-core/src/main/java/org/drools/learner/builder and 3 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Mon Jul 21 08:25:31 EDT 2008
Author: gizil
Date: 2008-07-21 08:25:31 -0400 (Mon, 21 Jul 2008)
New Revision: 21142
Modified:
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/car_c45_one.drl
Log:
updates for the pruner
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-07-21 09:06:47 UTC (rev 21141)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-07-21 12:25:31 UTC (rev 21142)
@@ -30,7 +30,7 @@
private String execution_signature;
public long FACTS_READ = 0;
- private int validation_error, training_error;
+ private double validation_error, training_error;
private int num_nonterminal_nodes;
@@ -116,17 +116,17 @@
}
- public void setValidationError(int error) {
+ public void setValidationError(double error) {
validation_error = error;
}
- public int getValidationError() {
+ public double getValidationError() {
return validation_error;
}
- public void setTrainingError(int error) {
+ public void setTrainingError(double error) {
training_error = error;
}
- public int getTrainingError() {
+ public double getTrainingError() {
return training_error;
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-07-21 09:06:47 UTC (rev 21141)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-07-21 12:25:31 UTC (rev 21142)
@@ -7,6 +7,7 @@
import org.drools.learner.eval.Estimator;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
+import org.drools.learner.tools.Util;
@@ -17,159 +18,171 @@
private Estimator procedure;
- private ArrayList<ArrayList<NodeUpdate>> updates;
+ private TreeStats best_stats;
+
private int num_trees_to_grow;
- private NodeUpdate best_update;
public DecisionTreePruner(Estimator proc) {
procedure = proc;
num_trees_to_grow = procedure.getEstimatorSize();
- updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
+ //updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
- best_update = new NodeUpdate(10000000.0d);
+ best_stats = new TreeStats(proc.getAlphaEstimate());
}
- public void prun_tree(DecisionTree tree) {
- // TODO Auto-generated method stub
+ public void prun_to_estimate() {
+ ArrayList<ArrayList<NodeUpdate>> updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
- }
-
- public void prun_to_estimate() {
+ ArrayList<ArrayList<TreeStats>> sequence_stats = new ArrayList<ArrayList<TreeStats>>(num_trees_to_grow);
+// private NodeUpdate best_update;
for (DecisionTree dt: procedure.getEstimators()) {
// dt.getId()
//dt.calc_num_node_leaves(dt.getRoot()); // this is done in the estimator
- ArrayList<NodeUpdate> tree_sequence = new ArrayList<NodeUpdate>();
- updates.add(tree_sequence);
- NodeUpdate init_tree = new NodeUpdate(dt.getValidationError());
- init_tree.setResubstitution_cost(dt.getTrainingError());
- init_tree.setAlpha(-1); // dont know
- init_tree.setCost_complexity(-1); // dont known
- init_tree.setDecisionTree(dt);
- init_tree.setNum_terminal_nodes(dt.getRoot().getNumLeaves());
- tree_sequence.add(init_tree);
- sequence_trees(dt);
+ TreeSequenceProc search = new TreeSequenceProc(dt, 100000.0d, new MinAlphaProc());
+
+ search.iterate_trees(0);
+ //updates.add(tree_sequence);
+ updates.add(search.getTreeSequence());
+ sequence_stats.add(search.getTreeSequenceStats());
+
// sort the found candidates
//Collections.sort(updates.get(dt.getId()), arg1)
int id =0;
System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
for (NodeUpdate nu: updates.get(dt.getId()) ){
- //System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
- System.out.println(id +"\t"+ nu.getNum_terminal_nodes()+"\t"+nu.getCross_validated_cost()+"\t"+nu.getResubstitution_cost()+"\t"+nu.getAlpha()+"\n");
+ //TODO What to print here?
+ //System.out.println(id +"\t"+ nu.getNum_terminal_nodes()+"\t"+nu.getCross_validated_cost()+"\t"+nu.getResubstitution_cost()+"\t"+nu.getAlpha()+"\n");
id++;
}
+
+ int sid =0;
+ System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
+ for (TreeStats st: sequence_stats.get(dt.getId()) ){
+ //System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
+ System.out.println(sid);
+ sid++;
+ }
+ int x =0;
}
}
- private void sequence_trees(DecisionTree dt_0) {
- if (slog.debug() !=null)
- slog.debug().log(dt_0.toString() +"\n");
- // if number of non-terminal nodes in the tree is more than 1
- // = if there exists at least one non-terminal node different than root
- if (dt_0.getNumNonTerminalNodes() < 1) {
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:TERMINATE-There is no non-terminal nodes? " + dt_0.getNumNonTerminalNodes() +"\n");
- return;
- } else if (dt_0.getNumNonTerminalNodes() == 1 && dt_0.getRoot().getNumLeaves()<=1) {
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:TERMINATE-There is only one node left which is root node " + dt_0.getNumNonTerminalNodes()+ " and it has only one leaf (pruned)" +dt_0.getRoot().getNumLeaves()+"\n");
- return;
- }
- // for each non-leaf subtree
- ArrayList<TreeNode> candidate_nodes = new ArrayList<TreeNode>();
- // to find all candidates with min_alpha value
- TreeSequenceProc search = new TreeSequenceProc(100000.0d, new MinAlphaProc());
+ public void prun_tree(DecisionTree tree) {
+ TreeSequenceProc search = new TreeSequenceProc(tree, best_stats.getAlpha(), new AnAlphaProc());
+ search.iterate_trees(0);
+ //search.getTreeSequence()// to go back
-
- search.find_candidate_nodes(dt_0, dt_0.getRoot(), candidate_nodes);
- double min_alpha = search.getTheAlpha();
- System.out.println("!!!!!!!!!!!alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
-
- if (candidate_nodes.size() >0) {
- // The one or more subtrees with that value of will be replaced by leaves
- // instead of getting the first node, have to process all nodes and prune all
- // write a method to prune all
- TreeNode best_node = candidate_nodes.get(0);
- LeafNode best_clone = new LeafNode(dt_0.getTargetDomain(), best_node.getLabel());
- best_clone.setRank( best_node.getRank());
- best_clone.setNumMatch(best_node.getNumMatch()); //num of matching instances to the leaf node
- best_clone.setNumClassification(best_node.getNumLabeled()); //num of (correctly) classified instances at the leaf node
-
- NodeUpdate update = new NodeUpdate(best_node, best_clone);
- //update.set
- update.setAlpha(min_alpha);
- update.setDecisionTree(dt_0);
- int k = numExtraMisClassIfPrun(best_node); // extra misclassified guys
- int num_leaves = best_node.getNumLeaves();
- int new_num_leaves = dt_0.getRoot().getNumLeaves() - num_leaves +1;
-
- TreeNode father_node = best_node.getFather();
- if (father_node != null) {
- for(Object key: father_node.getChildrenKeys()) {
- if (father_node.getChild(key).equals(best_node)) {
- father_node.putNode(key, best_clone);
- break;
- }
- }
- updateLeaves(father_node, -num_leaves+1);
- } else {
- // this node does not have any father node it is the root node of the tree
- dt_0.setRoot(best_clone);
- }
-
-
- ArrayList<InstanceList> sets = procedure.getFold(dt_0.getId());
- //InstanceList learning_set = sets.get(0);
- InstanceList validation_set = sets.get(1);
-
- int error = 0;
- SingleTreeTester t= new SingleTreeTester(dt_0);
- for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
- Integer result = t.test(validation_set.getInstance(index_i));
- if (result == Stats.INCORRECT) {
- error ++;
- }
- }
-
-
- update.setCross_validated_cost(error);
- int new_resubstitution_cost = dt_0.getTrainingError() + k;
- double cost_complexity = new_resubstitution_cost + min_alpha * (new_num_leaves);
-
-
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:cost_complexity of selected tree "+ cost_complexity +"\n");
- update.setResubstitution_cost(new_resubstitution_cost);
- // Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
- update.setCost_complexity(cost_complexity);
- update.setNum_terminal_nodes(new_num_leaves);
-
- updates.get(dt_0.getId()).add(update);
-
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:error "+ error +"<?"+ procedure.getValidationErrorEstimate() * 1.6 +"\n");
-
- if (error < procedure.getValidationErrorEstimate() * 1.6) {
- // if the error of the tree is not that bad
-
- if (error < best_update.getCross_validated_cost()) {
- best_update = update;
- if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:best node updated \n");
-
- }
-
- sequence_trees(dt_0);
- } else {
- update.setStopTree();
- return;
- }
- }
-
}
+
+// private void sequence_trees(DecisionTree dt_0, double init_alpha, AlphaSelectionProc proc) {
+// if (slog.debug() !=null)
+// slog.debug().log(dt_0.toString() +"\n");
+// // if number of non-terminal nodes in the tree is more than 1
+// // = if there exists at least one non-terminal node different than root
+// if (dt_0.getNumNonTerminalNodes() < 1) {
+// if (slog.debug() !=null)
+// slog.debug().log(":sequence_trees:TERMINATE-There is no non-terminal nodes? " + dt_0.getNumNonTerminalNodes() +"\n");
+// return;
+// } else if (dt_0.getNumNonTerminalNodes() == 1 && dt_0.getRoot().getNumLeaves()<=1) {
+// if (slog.debug() !=null)
+// slog.debug().log(":sequence_trees:TERMINATE-There is only one node left which is root node " + dt_0.getNumNonTerminalNodes()+ " and it has only one leaf (pruned)" +dt_0.getRoot().getNumLeaves()+"\n");
+// return;
+// }
+// // for each non-leaf subtree
+// ArrayList<TreeNode> candidate_nodes = new ArrayList<TreeNode>();
+// // to find all candidates with min_alpha value
+// TreeSequenceProc search = new TreeSequenceProc(init_alpha, proc);//100000.0d, new MinAlphaProc());
+//
+//
+// search.find_candidate_nodes(dt_0, dt_0.getRoot(), candidate_nodes);
+// double min_alpha = search.getTheAlpha();
+// System.out.println("!!!!!!!!!!!alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
+//
+// if (candidate_nodes.size() >0) {
+// // The one or more subtrees with that value of will be replaced by leaves
+// // instead of getting the first node, have to process all nodes and prune all
+// // write a method to prune all
+// TreeNode best_node = candidate_nodes.get(0);
+// LeafNode best_clone = new LeafNode(dt_0.getTargetDomain(), best_node.getLabel());
+// best_clone.setRank( best_node.getRank());
+// best_clone.setNumMatch(best_node.getNumMatch()); //num of matching instances to the leaf node
+// best_clone.setNumClassification(best_node.getNumLabeled()); //num of (correctly) classified instances at the leaf node
+//
+// NodeUpdate update = new NodeUpdate(best_node, best_clone);
+// //update.set
+// update.setAlpha(min_alpha);
+// update.setDecisionTree(dt_0);
+// int k = numExtraMisClassIfPrun(best_node); // extra misclassified guys
+// int num_leaves = best_node.getNumLeaves();
+// int new_num_leaves = dt_0.getRoot().getNumLeaves() - num_leaves +1;
+//
+// TreeNode father_node = best_node.getFather();
+// if (father_node != null) {
+// for(Object key: father_node.getChildrenKeys()) {
+// if (father_node.getChild(key).equals(best_node)) {
+// father_node.putNode(key, best_clone);
+// break;
+// }
+// }
+// updateLeaves(father_node, -num_leaves+1);
+// } else {
+// // this node does not have any father node it is the root node of the tree
+// dt_0.setRoot(best_clone);
+// }
+//
+//
+// ArrayList<InstanceList> sets = procedure.getFold(dt_0.getId());
+// //InstanceList learning_set = sets.get(0);
+// InstanceList validation_set = sets.get(1);
+//
+// int error = 0;
+// SingleTreeTester t= new SingleTreeTester(dt_0);
+// for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
+// Integer result = t.test(validation_set.getInstance(index_i));
+// if (result == Stats.INCORRECT) {
+// error ++;
+// }
+// }
+//
+//
+// update.setCross_validated_cost(error);
+// int new_resubstitution_cost = dt_0.getTrainingError() + k;
+// double cost_complexity = new_resubstitution_cost + min_alpha * (new_num_leaves);
+//
+//
+// if (slog.debug() !=null)
+// slog.debug().log(":sequence_trees:cost_complexity of selected tree "+ cost_complexity +"\n");
+// update.setResubstitution_cost(new_resubstitution_cost);
+// // Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
+// update.setCost_complexity(cost_complexity);
+// update.setNum_terminal_nodes(new_num_leaves);
+//
+// updates.get(dt_0.getId()).add(update);
+//
+// if (slog.debug() !=null)
+// slog.debug().log(":sequence_trees:error "+ error +"<?"+ procedure.getValidationErrorEstimate() * 1.6 +"\n");
+//
+// if (error < procedure.getValidationErrorEstimate() * 1.6) {
+// // if the error of the tree is not that bad
+//
+// if (error < best_update.getCross_validated_cost()) {
+// best_update = update;
+// if (slog.debug() !=null)
+// slog.debug().log(":sequence_trees:best node updated \n");
+//
+// }
+//
+// sequence_trees(dt_0);
+// } else {
+// update.setStopTree();
+// return;
+// }
+// }
+//
+// }
private void updateLeaves(TreeNode my_node, int i) {
my_node.setNumLeaves(my_node.getNumLeaves() + i);
@@ -201,16 +214,172 @@
public class TreeSequenceProc {
+
+ private static final double MAX_ERROR_RATIO = 0.99;
+ private DecisionTree focus_tree;
private double the_alpha;
private AlphaSelectionProc alpha_proc;
+ private ArrayList<NodeUpdate> tree_sequence;
+ private ArrayList<TreeStats> tree_sequence_stats;
- public TreeSequenceProc(double value, AlphaSelectionProc cond) {
- the_alpha = value;
+ private TreeStats best_tree_stats;
+ /**
+  * Creates a pruning-sequence search over {@code dt}.
+  * <p>
+  * The tree as handed in is recorded as the first entry of both the node-update
+  * sequence and the statistics sequence.
+  *
+  * @param dt the tree this search works on; {@code iterate_trees} replaces nodes of it in place
+  * @param init_alpha initial value of the alpha that the selection strategy compares against
+  * @param cond strategy that decides which nodes become pruning candidates
+  */
+ public TreeSequenceProc(DecisionTree dt, double init_alpha, AlphaSelectionProc cond) {
+     focus_tree = dt;
+     the_alpha = init_alpha;
+     alpha_proc = cond;
+     tree_sequence = new ArrayList<NodeUpdate>();
+     // The field has no initializer; without this the add() at the end of the
+     // constructor throws NullPointerException.
+     tree_sequence_stats = new ArrayList<TreeStats>();
+
+     // Start with a very large cross-validated cost so the first tree that is
+     // compared against it in iterate_trees wins.
+     best_tree_stats = new TreeStats(10000000.0d);
+
+     NodeUpdate init_tree = new NodeUpdate(dt.getValidationError());
+     tree_sequence.add(init_tree);
+
+     TreeStats init_tree_stats = new TreeStats(dt.getValidationError());
+     init_tree_stats.setResubstitution_cost(dt.getTrainingError());
+     init_tree_stats.setAlpha(-1); // not known for the initial tree
+     init_tree_stats.setCost_complexity(-1); // not known for the initial tree
+     init_tree_stats.setNum_terminal_nodes(dt.getRoot().getNumLeaves());
+     tree_sequence_stats.add(init_tree_stats);
+ }
+ /**
+  * Returns the statistics recorded so far, one entry per tree in the sequence.
+  * The list handed out is the internal one, not a copy.
+  */
+ public ArrayList<TreeStats> getTreeSequenceStats() {
+     return this.tree_sequence_stats;
+ }
+
+ /**
+  * Returns the node updates recorded so far, in the order they were applied.
+  * The list handed out is the internal one, not a copy.
+  */
+ public ArrayList<NodeUpdate> getTreeSequence() {
+     return this.tree_sequence;
+ }
+
+ /**
+  * Performs one pruning step on {@code focus_tree} and then recurses for the next step.
+  * <p>
+  * A step collects candidate nodes through {@code find_candidate_nodes}, replaces every
+  * candidate by a leaf, re-tests the tree on the validation set of its fold and appends
+  * a {@code TreeStats} entry describing the resulting tree. The recursion ends when the
+  * tree has nothing left to prune, when no candidate is found, or when the validation
+  * error ratio is not below {@code MAX_ERROR_RATIO}.
+  *
+  * @param i iteration index stamped on the statistics and node updates created in this step
+  */
+ private void iterate_trees(int i) {
+     if (slog.debug() !=null)
+         slog.debug().log(focus_tree.toString() +"\n");
+     // Nothing to prune: either there is no non-terminal node at all, or a single
+     // non-terminal node is left and the root already counts at most one leaf.
+     if (focus_tree.getNumNonTerminalNodes() < 1) {
+         if (slog.debug() !=null)
+             slog.debug().log(":sequence_trees:TERMINATE-There is no non-terminal nodes? " + focus_tree.getNumNonTerminalNodes() +"\n");
+         return;
+     } else if (focus_tree.getNumNonTerminalNodes() == 1 && focus_tree.getRoot().getNumLeaves()<=1) {
+         if (slog.debug() !=null)
+             slog.debug().log(":sequence_trees:TERMINATE-There is only one node left which is root node " + focus_tree.getNumNonTerminalNodes()+ " and it has only one leaf (pruned)" +focus_tree.getRoot().getNumLeaves()+"\n");
+         return;
+     }
+
+     ArrayList<TreeNode> candidate_nodes = new ArrayList<TreeNode>();
+     // NOTE(review): the_alpha is not reset before the search, so each iteration starts
+     // from the value the previous one left behind - confirm this is intended.
+     find_candidate_nodes(focus_tree.getRoot(), candidate_nodes);
+     double min_alpha = getTheAlpha();
+     System.out.println("!!!!!!!!!!!alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
+
+     if (candidate_nodes.size() >0) {
+         // Every candidate found in this step is replaced by a leaf, not just the first one.
+         TreeStats stats = new TreeStats();
+         stats.iteration_id(i);
+
+         int change_in_training_misclass = 0; // (k) summed over all candidates of this step
+         for (TreeNode candidate_node:candidate_nodes) {
+             // The replacement leaf takes over the candidate's label, rank and counters.
+             LeafNode best_clone = new LeafNode(focus_tree.getTargetDomain(), candidate_node.getLabel());
+             best_clone.setRank( candidate_node.getRank());
+             best_clone.setNumMatch(candidate_node.getNumMatch()); //num of matching instances to the leaf node
+             best_clone.setNumClassification(candidate_node.getNumLabeled()); //num of (correctly) classified instances at the leaf node
+
+             NodeUpdate update = new NodeUpdate(candidate_node, best_clone);
+             update.iteration_id(i);
+
+             update.setDecisionTree(focus_tree);
+             change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
+             int num_leaves = candidate_node.getNumLeaves();
+
+             TreeNode father_node = candidate_node.getFather();
+             if (father_node != null) {
+                 // Swap the candidate for the leaf under whichever key the father holds it.
+                 for(Object key: father_node.getChildrenKeys()) {
+                     if (father_node.getChild(key).equals(candidate_node)) {
+                         father_node.putNode(key, best_clone);
+                         break;
+                     }
+                 }
+                 // The subtree's leaves collapse into one, hence the "+1".
+                 updateLeaves(father_node, -num_leaves+1);
+             } else {
+                 // this node does not have any father node it is the root node of the tree
+                 focus_tree.setRoot(best_clone);
+             }
+
+             tree_sequence.add(update);
+
+         }
+
+         ArrayList<InstanceList> sets = procedure.getFold(focus_tree.getId());
+         //InstanceList learning_set = sets.get(0);
+         InstanceList validation_set = sets.get(1);
+
+         // Count the validation instances the pruned tree gets wrong.
+         int num_error = 0;
+         SingleTreeTester t= new SingleTreeTester(focus_tree);
+         for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
+             Integer result = t.test(validation_set.getInstance(index_i));
+             if (result == Stats.INCORRECT) {
+                 num_error ++;
+             }
+         }
+         // Despite the name this is a ratio (errors / validation set size), not a percentage.
+         double percent_error = Util.division(num_error, validation_set.getSize());
+
+         int new_num_leaves = focus_tree.getRoot().getNumLeaves();
+
+         double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_misclass, focus_tree.FACTS_READ);
+         // Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
+         double cost_complexity = new_resubstitution_cost + min_alpha * (new_num_leaves);
+
+
+         if (slog.debug() !=null)
+             slog.debug().log(":sequence_trees:cost_complexity of selected tree "+ cost_complexity +"\n");
+
+
+         stats.setAlpha(min_alpha);
+         stats.setCross_validated_cost(percent_error);
+         stats.setResubstitution_cost(new_resubstitution_cost);
+         stats.setCost_complexity(cost_complexity);
+         stats.setNum_terminal_nodes(new_num_leaves);
+         tree_sequence_stats.add(stats);
+
+         // Log the threshold that the comparison below really uses.
+         if (slog.debug() !=null)
+             slog.debug().log(":sequence_trees:error "+ percent_error +"<?"+ MAX_ERROR_RATIO +"\n");
+
+         if (percent_error < MAX_ERROR_RATIO) {
+             // if the error of the tree is not that bad
+
+             if (percent_error < best_tree_stats.getCross_validated_cost()) {
+                 best_tree_stats = stats;
+                 if (slog.debug() !=null)
+                     slog.debug().log(":sequence_trees:best node updated \n");
+
+             }
+
+             iterate_trees(i+1);
+         } else {
+             //TODO update.setStopTree();
+             return;
+         }
+     }
+
+ }
+
// memory optimized
- public void find_candidate_nodes(DecisionTree dt, TreeNode my_node, ArrayList<TreeNode> nodes) {
+ public void find_candidate_nodes(TreeNode my_node, ArrayList<TreeNode> nodes) {
if (my_node instanceof LeafNode) {
@@ -222,15 +391,15 @@
int k = numExtraMisClassIfPrun(my_node);
int num_leaves = my_node.getNumLeaves();
- double alpha = ((double)k)/((double)dt.FACTS_READ * (num_leaves-1));
+ double alpha = ((double)k)/((double)focus_tree.FACTS_READ * (num_leaves-1));
if (slog.debug() !=null)
- slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+the_alpha+ " k "+k+" num_leaves "+num_leaves+" all "+ dt.FACTS_READ + "\n");
+ slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+the_alpha+ " k "+k+" num_leaves "+num_leaves+" all "+ focus_tree.FACTS_READ + "\n");
the_alpha = alpha_proc.update_nodes(alpha, the_alpha, my_node, nodes);
for (Object attributeValue : my_node.getChildrenKeys()) {
TreeNode child = my_node.getChild(attributeValue);
- find_candidate_nodes(dt, child, nodes);
+ find_candidate_nodes(child, nodes);
//nodes.pop();
}
}
@@ -242,9 +411,43 @@
}
}
+ /**
+  * Tells whether {@code found_node} is among the ancestors of {@code cur_node}.
+  * The chain of fathers is followed all the way up, so indirect ancestors count as well.
+  *
+  * @param cur_node node whose ancestors are inspected
+  * @param found_node node looked for among those ancestors
+  * @return {@code true} if some father along the chain equals {@code found_node}
+  */
+ public boolean isChildOf(TreeNode cur_node, TreeNode found_node) {
+     for (TreeNode ancestor = cur_node.getFather(); ancestor != null; ancestor = ancestor.getFather()) {
+         if (ancestor.equals(found_node))
+             return true;
+     }
+     // ran past the top of the chain without meeting found_node
+     return false;
+ }
+
public interface AlphaSelectionProc {
public double update_nodes(double cur_alpha, double the_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes);
}
+
+ /**
+  * Selection strategy that collects the nodes whose alpha is exactly equal to the alpha
+  * it is compared against, leaving out any node that sits below an already collected node.
+  */
+ public class AnAlphaProc implements AlphaSelectionProc{
+
+     /**
+      * Adds {@code cur_node} to {@code nodes} when its alpha matches and none of the
+      * collected nodes is one of its ancestors.
+      * <p>
+      * NOTE(review): the return value is always {@code cur_alpha}, and
+      * {@code find_candidate_nodes} stores the return value as its new {@code the_alpha};
+      * confirm that overwriting the comparison value on every visited node is intended.
+      *
+      * @return {@code cur_alpha}, in every case
+      */
+     public double update_nodes(double cur_alpha, double the_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes) {
+         if (cur_alpha != the_alpha)
+             return cur_alpha;
+
+         for (TreeNode collected : nodes) {
+             // an ancestor is already in the set, so this node is not added
+             if (isChildOf(cur_node, collected))
+                 return cur_alpha;
+         }
+         nodes.add(cur_node);
+         return cur_alpha;
+     }
+ }
+
public class MinAlphaProc implements AlphaSelectionProc{
// private double best_alpha;
// public MinAlphaProc(double value) {
@@ -253,6 +456,10 @@
public double update_nodes(double cur_alpha, double the_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes) {
if (cur_alpha == the_alpha) {
+ for(TreeNode parent:nodes) {
+ if (isChildOf(cur_node, parent))
+ return cur_alpha;// it is not added
+ }
// add this one to the set
nodes.add(cur_node);
return cur_alpha;
@@ -276,26 +483,6 @@
// }
}
- public class AnAlphaProc implements AlphaSelectionProc{
-// private double an_alpha;
-// public AnAlphaProc(double value) {
-// an_alpha = value;
-// }
-
- public double update_nodes(double cur_alpha, double the_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes) {
-
- if (cur_alpha == the_alpha) {
- // add this one to the set
- nodes.add(cur_node);
- }
- return cur_alpha;
- }
-
-// public double getAlpha() {
-// return an_alpha;
-// }
- }
-
public class NodeUpdate{
private boolean stopTree;
@@ -303,15 +490,20 @@
private DecisionTree tree;
private TreeNode old_node, node_update;
- private int num_terminal_nodes;
- private double cross_validated_cost;
- private double resubstitution_cost;
- private double cost_complexity;
- private double alpha;
+ private int iteration_id;
+ //public TreeStats stats;
+// private int num_terminal_nodes;
+// private double cross_validated_cost;
+// private double resubstitution_cost;
+// private double cost_complexity;
+// private double alpha;
+
// to set an node update with the worst cross validated error
public NodeUpdate(double error) {
- cross_validated_cost = error;
+ //stats = new TreeStats(error);
+ //cross_validated_cost = error;
+
stopTree = false;
old_node = null;
node_update = null;
@@ -321,6 +513,9 @@
old_node = old_n;
node_update = new_n;
}
+ /** Records the index of the pruning iteration that produced this update. */
+ public void iteration_id(int i) {
+     this.iteration_id = i;
+ }
public void setDecisionTree(DecisionTree dt_0) {
tree = dt_0;
@@ -330,6 +525,29 @@
stopTree = true;
}
+
+ }
+ public class TreeStats{
+
+ private int iteration_id;
+ private int num_terminal_nodes;
+ private double cross_validated_cost;
+ private double resubstitution_cost;
+ private double cost_complexity;
+ private double alpha;
+
+ /** Creates statistics tagged with iteration 0; no other field is assigned here. */
+ public TreeStats() {
+     this.iteration_id = 0;
+ }
+ /**
+  * Creates statistics tagged with iteration 0 and a preset cross-validated cost,
+  * e.g. a very large value to serve as the "worst so far" starting point.
+  *
+  * @param error initial cross-validated cost
+  */
+ public TreeStats(double error) {
+     this.iteration_id = 0;
+     this.cross_validated_cost = error;
+ }
+
+ /** Records the index of the pruning iteration these statistics belong to. */
+ public void iteration_id(int i) {
+     this.iteration_id = i;
+ }
public int getNum_terminal_nodes() {
return num_terminal_nodes;
}
@@ -368,7 +586,9 @@
public void setAlpha(double alpha) {
this.alpha = alpha;
- }
+ }
+
}
+
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-07-21 09:06:47 UTC (rev 21141)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-07-21 12:25:31 UTC (rev 21142)
@@ -66,7 +66,7 @@
/* we need to know how many guys cannot be classified and who these guys are */
data_stats.missClassifiedInstances(missclassified_data);
- dt.setTrainingError((int) (dt.getTrainingError() + data_stats.getSum()));
+ dt.setTrainingError(dt.getTrainingError() + data_stats.getSum()/dt.FACTS_READ);
return noAttributeLeftNode;
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java 2008-07-21 09:06:47 UTC (rev 21141)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java 2008-07-21 12:25:31 UTC (rev 21142)
@@ -102,14 +102,14 @@
error ++;
}
}
- dt.setValidationError(error);
+ dt.setValidationError(Util.division(error, fold_size));
dt.calc_num_node_leaves(dt.getRoot());
if (slog.error() !=null)
slog.error().log("The estimate of : "+(i-1)+" training=" +dt.getTrainingError() +" valid=" + error +" num_leaves=" + dt.getRoot().getNumLeaves()+"\n");
- validation_error_estimate += (double)error/(double)k_fold;
- training_error_estimate += (double)dt.getTrainingError()/(double)k_fold;
+ validation_error_estimate += ((double)error/(double) fold_size)/(double)k_fold;
+ training_error_estimate += ((double)dt.getTrainingError()/(double)(num_instances-fold_size))/(double)k_fold;
num_leaves_estimate += (double)dt.getRoot().getNumLeaves()/(double)k_fold;
@@ -190,7 +190,9 @@
//
// }
-
+ /** Returns the current value of {@code alpha_estimate}. */
+ public double getAlphaEstimate() {
+ return alpha_estimate;
+ }
private int getFoldSize(int i) {
int excess = num_instances % k_fold;
return (int) num_instances/k_fold + (i < excess? 1:0);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java 2008-07-21 09:06:47 UTC (rev 21141)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java 2008-07-21 12:25:31 UTC (rev 21142)
@@ -12,4 +12,5 @@
public ArrayList<InstanceList> getFold(int id);
public double getValidationErrorEstimate();
+ public double getAlphaEstimate();
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java 2008-07-21 09:06:47 UTC (rev 21141)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java 2008-07-21 12:25:31 UTC (rev 21142)
@@ -59,6 +59,13 @@
return Math.exp(prob);
}
+ /**
+  * Divides {@code x} by {@code y} in floating point, so the fractional part is kept.
+  * A zero {@code y} does not throw; the result is an infinity or NaN.
+  */
+ public static double division(int x, int y) {
+     final double numerator = x;
+     final double denominator = y;
+     return numerator / denominator;
+ }
+ /**
+  * Divides {@code x} by the long {@code y} in floating point, so the fractional part is kept.
+  * A zero {@code y} does not throw; the result is an infinity or NaN.
+  */
+ public static double division(int x, long y) {
+     final double numerator = x;
+     final double denominator = y;
+     return numerator / denominator;
+ }
+
/* TODO make this all_fields arraylist as hashmap */
public static void getSuperFields(Class<?> clazz, ArrayList<Field> all_fields) {
if (clazz == Object.class)
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/car_c45_one.drl
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/car_c45_one.drl 2008-07-21 09:06:47 UTC (rev 21141)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/rules/org/drools/examples/learner/car_c45_one.drl 2008-07-21 12:25:31 UTC (rev 21142)
@@ -2,68 +2,75 @@
import org.drools.examples.learner.Car
-rule "#0 label2= false classifying 432.0 num of facts with rank:0.25"
+rule "#131 target= unacc classifying 576.0 num of facts with rank:0.3333333333333333"
when
- $car_0 : Car(buying == "high", $target_label : label2 )
+ $car_0 : Car(safety == "low", $target_label : target )
then
- System.out.println("[label2] Expected value (" + $target_label + "), Classified as (false )");
+ System.out.println("[target] Expected value (" + $target_label + "), Classified as (unacc )");
end
-rule "#4 label2= false classifying 432.0 num of facts with rank:0.25"
+rule "#59 target= unacc classifying 192.0 num of facts with rank:0.1111111111111111"
when
- $car_0 : Car(buying == "vhigh", $target_label : label2 )
+ $car_0 : Car(safety == "med", persons == "2", $target_label : target )
then
- System.out.println("[label2] Expected value (" + $target_label + "), Classified as (false )");
+ System.out.println("[target] Expected value (" + $target_label + "), Classified as (unacc )");
end
-rule "#6 label2= false classifying 432.0 num of facts with rank:0.25"
+rule "#140 target= unacc classifying 192.0 num of facts with rank:0.1111111111111111"
when
- $car_0 : Car(buying == "med", $target_label : label2 )
+ $car_0 : Car(safety == "high", persons == "2", $target_label : target )
then
- System.out.println("[label2] Expected value (" + $target_label + "), Classified as (false )");
+ System.out.println("[target] Expected value (" + $target_label + "), Classified as (unacc )");
end
-rule "#5 label2= false classifying 108.0 num of facts with rank:0.0625"
+rule "#145 target= unacc classifying 16.0 num of facts with rank:0.009259259259259259"
when
- $car_0 : Car(buying == "low", doors == "4", $target_label : label2 )
+ $car_0 : Car(safety == "med", persons == "4", buying == "high", lug_boot == "small", $target_label : target )
then
- System.out.println("[label2] Expected value (" + $target_label + "), Classified as (false )");
+ System.out.println("[target] Expected value (" + $target_label + "), Classified as (unacc )");
end
-rule "#7 label2= false classifying 108.0 num of facts with rank:0.0625"
+rule "#172 target= unacc classifying 16.0 num of facts with rank:0.009259259259259259"
when
- $car_0 : Car(buying == "low", doors == "2", $target_label : label2 )
+ $car_0 : Car(safety == "med", persons == "more", buying == "high", lug_boot == "small", $target_label : target )
then
- System.out.println("[label2] Expected value (" + $target_label + "), Classified as (false )");
+ System.out.println("[target] Expected value (" + $target_label + "), Classified as (unacc )");
end
-rule "#8 label2= false classifying 108.0 num of facts with rank:0.0625"
+rule "#8 target= unacc classifying 12.0 num of facts with rank:0.006944444444444444"
when
- $car_0 : Car(buying == "low", doors == "3", $target_label : label2 )
+ $car_0 : Car(safety == "med", persons == "4", buying == "vhigh", maint == "vhigh", $target_label : target )
then
- System.out.println("[label2] Expected value (" + $target_label + "), Classified as (false )");
+ System.out.println("[target] Expected value (" + $target_label + "), Classified as (unacc )");
end
-rule "#1 label2= true classifying 36.0 num of facts with rank:0.020833333333333332"
+rule "#20 target= acc classifying 12.0 num of facts with rank:0.006944444444444444"
when
- $car_0 : Car(buying == "low", doors == "5more", safety == "med", $target_label : label2 )
+ $car_0 : Car(safety == "med", persons == "4", buying == "low", maint == "high", $target_label : target )
then
- System.out.println("[label2] Expected value (" + $target_label + "), Classified as (true )");
+ System.out.println("[target] Expected value (" + $target_label + "), Classified as (acc )");
end
-rule "#2 label2= false classifying 36.0 num of facts with rank:0.020833333333333332"
+rule "#34 target= acc classifying 12.0 num of facts with rank:0.006944444444444444"
when
- $car_0 : Car(buying == "low", doors == "5more", safety == "high", $target_label : label2 )
+ $car_0 : Car(safety == "high", persons == "4", buying == "med", maint == "vhigh", $target_label : target )
then
- System.out.println("[label2] Expected value (" + $target_label + "), Classified as (false )");
+ System.out.println("[target] Expected value (" + $target_label + "), Classified as (acc )");
end
-rule "#3 label2= false classifying 36.0 num of facts with rank:0.020833333333333332"
+rule "#41 target= unacc classifying 12.0 num of facts with rank:0.006944444444444444"
when
- $car_0 : Car(buying == "low", doors == "5more", safety == "low", $target_label : label2 )
+ $car_0 : Car(safety == "med", persons == "more", buying == "vhigh", maint == "vhigh", $target_label : target )
then
- System.out.println("[label2] Expected value (" + $target_label + "), Classified as (false )");
+ System.out.println("[target] Expected value (" + $target_label + "), Classified as (unacc )");
end
-//THE END: Total number of facts correctly classified= 1728 over 1728.0
-//with 9 number of rules over 9 total number of rules
+rule "#43 target= unacc classifying 12.0 num of facts with rank:0.006944444444444444"
+ when
+ $car_0 : Car(safety == "med", persons == "4", buying == "vhigh", maint == "high", $target_label : target )
+ then
+ System.out.println("[target] Expected value (" + $target_label + "), Classified as (unacc )");
+end
+
+//THE END: Total number of facts correctly classified= 1052 over 1728.0
+//with 10 number of rules over 10 total number of rules
More information about the jboss-svn-commits
mailing list