[jboss-svn-commits] JBL Code SVN: r21321 - in labs/jbossrules/contrib/machinelearning/5.0: drools-core/src/main/java/org/drools/learner/builder and 3 other directories.
jboss-svn-commits at lists.jboss.org
Fri Aug 1 09:12:34 EDT 2008
Author: gizil
Date: 2008-08-01 09:12:34 -0400 (Fri, 01 Aug 2008)
New Revision: 21321
Added:
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/EstimatedNodeSize.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ImpurityDecrease.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/MaximumDepth.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/StoppingCriterion.java
Modified:
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/TreeNode.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
Log:
Add stopping criterion classes (StoppingCriterion, EstimatedNodeSize, ImpurityDecrease, MaximumDepth), track node depth during training, rework pruning/alpha selection in DecisionTreePruner, and split training-data-size accounting into per-tree and total sizes.
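For orientation, the sketch below shows how the new stopping criteria are wired into training, following the C45Learner and DecisionTreeFactory changes further down; it is a sketch only, and the Memory import path and working-memory setup are assumed from the surrounding code rather than shown in this diff.

package org.drools.examples.learner;

import java.util.ArrayList;

import org.drools.WorkingMemory;
import org.drools.learner.DecisionTree;
import org.drools.learner.Memory;
import org.drools.learner.builder.C45Learner;
import org.drools.learner.builder.Learner;
import org.drools.learner.builder.SingleTreeBuilder;
import org.drools.learner.eval.Entropy;
import org.drools.learner.eval.EstimatedNodeSize;
import org.drools.learner.eval.StoppingCriterion;

public class StoppingCriteriaSketch {

    public static DecisionTree train(WorkingMemory wm, Class<? extends Object> obj_class) {
        // tree growth stops at a node as soon as any criterion in this list fires
        ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>();
        criteria.add(new EstimatedNodeSize(0.5)); // stop on nodes far smaller than the expected N/(b^d)

        C45Learner learner = new C45Learner(new Entropy(), criteria);
        SingleTreeBuilder builder = new SingleTreeBuilder();

        Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), Learner.DEFAULT_DATA);
        builder.build(mem, learner);
        return learner.getTree();
    }
}

This mirrors createSingleC45_Stop below; the prune-then-stop variant additionally runs CrossValidation and DecisionTreePruner over the same learner.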
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -135,7 +135,7 @@
return 1;
} else {
- num_nonterminal_nodes ++;
+ num_nonterminal_nodes ++; // TODO does this really work?
}
int leaves = 0;
@@ -150,6 +150,69 @@
return leaves;
}
+ //private ArrayList<LeafNode> leaf_nodes;
+ public ArrayList<LeafNode> getLeaves(TreeNode start_node) {
+ ArrayList<LeafNode> terminal_nodes = new ArrayList<LeafNode>();
+
+ find_leaves(terminal_nodes, start_node);
+
+ return terminal_nodes;
+ }
+
+ private int find_leaves(ArrayList<LeafNode> terminals, TreeNode my_node) {
+ if (my_node instanceof LeafNode) {
+ terminals.add((LeafNode)my_node);
+ return 1;
+ } else {
+ num_nonterminal_nodes ++; // TODO does this really work?
+ }
+
+ int leaves = 0;
+ for (Object child_key: my_node.getChildrenKeys()) {
+ /* split the last two class at the same time */
+
+ TreeNode child = my_node.getChild(child_key);
+ leaves += find_leaves(terminals, child);
+
+ }
+ //my_node.setNumLeaves(leaves);
+ return leaves;
+ }
+
+ public ArrayList<TreeNode> getAnchestor_of_Leaves(TreeNode start_node) {
+ ArrayList<LeafNode> terminal_nodes = new ArrayList<LeafNode>();
+
+ ArrayList<TreeNode> anc_terminal_nodes = new ArrayList<TreeNode>();
+
+ find_leaves(terminal_nodes, anc_terminal_nodes, start_node);
+
+ return anc_terminal_nodes;
+ }
+
+ private int find_leaves(ArrayList<LeafNode> terminals, ArrayList<TreeNode> anchestors, TreeNode my_node) {
+
+ int leaves = 0;
+ boolean anchestor_added = false;
+ for (Object child_key: my_node.getChildrenKeys()) {
+ /* split the last two class at the same time */
+
+ TreeNode child = my_node.getChild(child_key);
+ if (child instanceof LeafNode) {
+ terminals.add((LeafNode)my_node);
+ if (!anchestor_added) {
+ num_nonterminal_nodes ++; // TODO does this really work?
+ anchestors.add(my_node);
+ anchestor_added = true;
+ }
+ return 1;
+ } else {
+ leaves += find_leaves(terminals, anchestors, child);
+ }
+ }
+ //my_node.setNumLeaves(leaves);
+ return leaves;
+ }
+
public int getNumNonTerminalNodes() {
return num_nonterminal_nodes;
}
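A minimal usage sketch of the new leaf accessors added above (assuming an already trained tree; import paths are assumed to match this package):

import java.util.ArrayList;

import org.drools.learner.DecisionTree;
import org.drools.learner.LeafNode;
import org.drools.learner.TreeNode;

public class LeafAccessSketch {
    public static void dump(DecisionTree dt) {
        ArrayList<LeafNode> leaves = dt.getLeaves(dt.getRoot());
        ArrayList<TreeNode> leafParents = dt.getAnchestor_of_Leaves(dt.getRoot());
        // every returned parent has at least one direct LeafNode child
        System.out.println("leaves=" + leaves.size() + " parents_of_leaves=" + leafParents.size());
    }
}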
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -22,30 +22,41 @@
private int num_trees_to_grow;
+ private double INIT_ALPHA = 0.5d;
+
public DecisionTreePruner(Estimator proc) {
procedure = proc;
num_trees_to_grow = procedure.getEstimatorSize();
//updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
- best_stats = new TreeStats(proc.getAlphaEstimate());
+ best_stats = new TreeStats(0.0);//proc.getAlphaEstimate());
}
public void prun_to_estimate() {
ArrayList<ArrayList<NodeUpdate>> updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
-
ArrayList<ArrayList<TreeStats>> sequence_stats = new ArrayList<ArrayList<TreeStats>>(num_trees_to_grow);
+ ArrayList<MinAlphaProc> alpha_procs = new ArrayList<MinAlphaProc>(num_trees_to_grow);
+
+ /*
+ * The best tree is selected from this series of trees with the classification error not exceeding
+ * an expected error rate on some test set (cross-validation error),
+ * which is done at the second stage.
+ */
+ double value_to_select = procedure.getErrorEstimate();
// private NodeUpdate best_update;
for (DecisionTree dt: procedure.getEstimators()) {
// dt.getId()
//dt.calc_num_node_leaves(dt.getRoot()); // this is done in the estimator
- TreeSequenceProc search = new TreeSequenceProc(dt, 100000.0d, new MinAlphaProc());
+ MinAlphaProc alpha_proc = new MinAlphaProc(INIT_ALPHA);
+ TreeSequenceProc search = new TreeSequenceProc(dt, alpha_proc);//INIT_ALPHA
search.iterate_trees(0);
//updates.add(tree_sequence);
updates.add(search.getTreeSequence());
sequence_stats.add(search.getTreeSequenceStats());
+ alpha_procs.add(alpha_proc);
// sort the found candidates
//Collections.sort(updates.get(dt.getId()), arg1)
@@ -62,127 +73,32 @@
System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
for (TreeStats st: sequence_stats.get(dt.getId()) ){
//System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
- System.out.println(sid);
+ System.out.println(sid+ "" +st.getAlpha() +" "+ st.getTest_cost());
sid++;
+
}
int x =0;
}
+
}
+ public void select_tree () {
+ /*
+ * The best tree is selected from this series of trees with the classification error not exceeding
+ * an expected error rate on some test set (cross-validation error),
+ * which is done at the second stage.
+ */
+ double value_to_select = procedure.getErrorEstimate();
+
+ }
+
public void prun_tree(DecisionTree tree) {
- TreeSequenceProc search = new TreeSequenceProc(tree, best_stats.getAlpha(), new AnAlphaProc());
+ TreeSequenceProc search = new TreeSequenceProc(tree, new AnAlphaProc(best_stats.getAlpha()));
search.iterate_trees(0);
//search.getTreeSequence()// to go back
}
-
-// private void sequence_trees(DecisionTree dt_0, double init_alpha, AlphaSelectionProc proc) {
-// if (slog.debug() !=null)
-// slog.debug().log(dt_0.toString() +"\n");
-// // if number of non-terminal nodes in the tree is more than 1
-// // = if there exists at least one non-terminal node different than root
-// if (dt_0.getNumNonTerminalNodes() < 1) {
-// if (slog.debug() !=null)
-// slog.debug().log(":sequence_trees:TERMINATE-There is no non-terminal nodes? " + dt_0.getNumNonTerminalNodes() +"\n");
-// return;
-// } else if (dt_0.getNumNonTerminalNodes() == 1 && dt_0.getRoot().getNumLeaves()<=1) {
-// if (slog.debug() !=null)
-// slog.debug().log(":sequence_trees:TERMINATE-There is only one node left which is root node " + dt_0.getNumNonTerminalNodes()+ " and it has only one leaf (pruned)" +dt_0.getRoot().getNumLeaves()+"\n");
-// return;
-// }
-// // for each non-leaf subtree
-// ArrayList<TreeNode> candidate_nodes = new ArrayList<TreeNode>();
-// // to find all candidates with min_alpha value
-// TreeSequenceProc search = new TreeSequenceProc(init_alpha, proc);//100000.0d, new MinAlphaProc());
-//
-//
-// search.find_candidate_nodes(dt_0, dt_0.getRoot(), candidate_nodes);
-// double min_alpha = search.getTheAlpha();
-// System.out.println("!!!!!!!!!!!alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
-//
-// if (candidate_nodes.size() >0) {
-// // The one or more subtrees with that value of will be replaced by leaves
-// // instead of getting the first node, have to process all nodes and prune all
-// // write a method to prune all
-// TreeNode best_node = candidate_nodes.get(0);
-// LeafNode best_clone = new LeafNode(dt_0.getTargetDomain(), best_node.getLabel());
-// best_clone.setRank( best_node.getRank());
-// best_clone.setNumMatch(best_node.getNumMatch()); //num of matching instances to the leaf node
-// best_clone.setNumClassification(best_node.getNumLabeled()); //num of (correctly) classified instances at the leaf node
-//
-// NodeUpdate update = new NodeUpdate(best_node, best_clone);
-// //update.set
-// update.setAlpha(min_alpha);
-// update.setDecisionTree(dt_0);
-// int k = numExtraMisClassIfPrun(best_node); // extra misclassified guys
-// int num_leaves = best_node.getNumLeaves();
-// int new_num_leaves = dt_0.getRoot().getNumLeaves() - num_leaves +1;
-//
-// TreeNode father_node = best_node.getFather();
-// if (father_node != null) {
-// for(Object key: father_node.getChildrenKeys()) {
-// if (father_node.getChild(key).equals(best_node)) {
-// father_node.putNode(key, best_clone);
-// break;
-// }
-// }
-// updateLeaves(father_node, -num_leaves+1);
-// } else {
-// // this node does not have any father node it is the root node of the tree
-// dt_0.setRoot(best_clone);
-// }
-//
-//
-// ArrayList<InstanceList> sets = procedure.getFold(dt_0.getId());
-// //InstanceList learning_set = sets.get(0);
-// InstanceList validation_set = sets.get(1);
-//
-// int error = 0;
-// SingleTreeTester t= new SingleTreeTester(dt_0);
-// for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
-// Integer result = t.test(validation_set.getInstance(index_i));
-// if (result == Stats.INCORRECT) {
-// error ++;
-// }
-// }
-//
-//
-// update.setCross_validated_cost(error);
-// int new_resubstitution_cost = dt_0.getTrainingError() + k;
-// double cost_complexity = new_resubstitution_cost + min_alpha * (new_num_leaves);
-//
-//
-// if (slog.debug() !=null)
-// slog.debug().log(":sequence_trees:cost_complexity of selected tree "+ cost_complexity +"\n");
-// update.setResubstitution_cost(new_resubstitution_cost);
-// // Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
-// update.setCost_complexity(cost_complexity);
-// update.setNum_terminal_nodes(new_num_leaves);
-//
-// updates.get(dt_0.getId()).add(update);
-//
-// if (slog.debug() !=null)
-// slog.debug().log(":sequence_trees:error "+ error +"<?"+ procedure.getValidationErrorEstimate() * 1.6 +"\n");
-//
-// if (error < procedure.getValidationErrorEstimate() * 1.6) {
-// // if the error of the tree is not that bad
-//
-// if (error < best_update.getCross_validated_cost()) {
-// best_update = update;
-// if (slog.debug() !=null)
-// slog.debug().log(":sequence_trees:best node updated \n");
-//
-// }
-//
-// sequence_trees(dt_0);
-// } else {
-// update.setStopTree();
-// return;
-// }
-// }
-//
-// }
private void updateLeaves(TreeNode my_node, int i) {
my_node.setNumLeaves(my_node.getNumLeaves() + i);
@@ -217,32 +133,27 @@
private static final double MAX_ERROR_RATIO = 0.99;
private DecisionTree focus_tree;
- private double the_alpha;
+ //private double the_alpha;
private AlphaSelectionProc alpha_proc;
private ArrayList<NodeUpdate> tree_sequence;
private ArrayList<TreeStats> tree_sequence_stats;
private TreeStats best_tree_stats;
- public TreeSequenceProc(DecisionTree dt, double init_alpha, AlphaSelectionProc cond) {
+ public TreeSequenceProc(DecisionTree dt, AlphaSelectionProc cond) { //, double init_alpha
focus_tree = dt;
- the_alpha = init_alpha;
+ //the_alpha = init_alpha;
alpha_proc = cond;
tree_sequence = new ArrayList<NodeUpdate>();
+ tree_sequence_stats = new ArrayList<TreeStats>();
best_tree_stats = new TreeStats(10000000.0d);
-
-// init_tree.setResubstitution_cost(dt.getTrainingError());
-// init_tree.setAlpha(-1); // dont know
-// init_tree.setCost_complexity(-1); // dont known
-// init_tree.setDecisionTree(dt);
-// init_tree.setNum_terminal_nodes(dt.getRoot().getNumLeaves());
NodeUpdate init_tree = new NodeUpdate(dt.getValidationError());
tree_sequence.add(init_tree);
TreeStats init_tree_stats = new TreeStats(dt.getValidationError());
init_tree_stats.setResubstitution_cost(dt.getTrainingError());
- init_tree_stats.setAlpha(-1); // dont know
+ init_tree_stats.setAlpha(0.0d); // dont know
init_tree_stats.setCost_complexity(-1); // dont known
// init_tree_stats.setDecisionTree(dt);
init_tree_stats.setNum_terminal_nodes(dt.getRoot().getNumLeaves());
@@ -281,7 +192,7 @@
find_candidate_nodes(focus_tree.getRoot(), candidate_nodes);
//double min_alpha = search.getTheAlpha();
double min_alpha = getTheAlpha();
- System.out.println("!!!!!!!!!!!alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
+ System.out.println("!!!!!!!!!!! dt:"+focus_tree.getId()+" ite "+i+" alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
if (candidate_nodes.size() >0) {
// The one or more subtrees with that value of will be replaced by leaves
@@ -340,7 +251,7 @@
int new_num_leaves = focus_tree.getRoot().getNumLeaves();
- double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_misclass, focus_tree.FACTS_READ);
+ double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_misclass, procedure.getTrainingDataSize(focus_tree.getId())/*focus_tree.FACTS_READ*/);
double cost_complexity = new_resubstitution_cost + min_alpha * (new_num_leaves);
@@ -349,7 +260,7 @@
stats.setAlpha(min_alpha);
- stats.setCross_validated_cost(percent_error);
+ stats.setTest_cost(percent_error);
stats.setResubstitution_cost(new_resubstitution_cost);
// Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
stats.setCost_complexity(cost_complexity);
@@ -357,23 +268,27 @@
tree_sequence_stats.add(stats);
if (slog.debug() !=null)
- slog.debug().log(":sequence_trees:error "+ percent_error +"<?"+ procedure.getValidationErrorEstimate() * 1.6 +"\n");
+ slog.debug().log(":sequence_trees:error "+ percent_error +"<?"+ procedure.getErrorEstimate() * 1.6 +"\n");
if (percent_error < MAX_ERROR_RATIO) { //procedure.getValidationErrorEstimate() * 1.6) {
// if the error of the tree is not that bad
- if (percent_error < best_tree_stats.getCross_validated_cost()) {
+ if (percent_error < best_tree_stats.getTest_cost()) {
best_tree_stats = stats;
if (slog.debug() !=null)
slog.debug().log(":sequence_trees:best node updated \n");
}
-
+ // TODO update alpha_proc by increasing the min_alpha the_alpha += 10;
+ alpha_proc.init_proc(alpha_proc.getAlpha() + INIT_ALPHA);
iterate_trees(i+1);
} else {
//TODO update.setStopTree();
return;
}
+ } else {
+ if (slog.debug() !=null)
+ slog.debug().log(":sequence_trees:no candidate node is found ???? \n");
}
}
@@ -391,11 +306,17 @@
int k = numExtraMisClassIfPrun(my_node);
int num_leaves = my_node.getNumLeaves();
- double alpha = ((double)k)/((double)focus_tree.FACTS_READ * (num_leaves-1));
+ if (k==0) {
+ if (slog.debug() !=null)
+ slog.debug().log(":search_alphas:k == 0\n" );
+
+ }
+ double alpha = ((double)k)/((double)procedure.getTrainingDataSize(focus_tree.getId())/*focus_tree.FACTS_READ*/ * (num_leaves-1));
if (slog.debug() !=null)
- slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+the_alpha+ " k "+k+" num_leaves "+num_leaves+" all "+ focus_tree.FACTS_READ + "\n");
+ slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+alpha_proc.getAlpha()+ " k "+k+" num_leaves "+num_leaves+" all "+ procedure.getTrainingDataSize(focus_tree.getId()) + "\n");
- the_alpha = alpha_proc.update_nodes(alpha, the_alpha, my_node, nodes);
+ //the_alpha = alpha_proc.check_node(alpha, the_alpha, my_node, nodes);
+ alpha_proc.check_node(alpha, my_node, nodes);
for (Object attributeValue : my_node.getChildrenKeys()) {
TreeNode child = my_node.getChild(attributeValue);
@@ -407,7 +328,7 @@
}
public double getTheAlpha() {
- return the_alpha;
+ return alpha_proc.getAlpha();
}
}
@@ -422,17 +343,22 @@
}
public interface AlphaSelectionProc {
- public double update_nodes(double cur_alpha, double the_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes);
+ public double check_node(double cur_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes);
+ public void init_proc(double value);
+ public double getAlpha();
}
public class AnAlphaProc implements AlphaSelectionProc{
-// private double an_alpha;
-// public AnAlphaProc(double value) {
-// an_alpha = value;
-// }
+ private double an_alpha;
+ public AnAlphaProc(double value) {
+ an_alpha = value;
+ }
+ public void init_proc(double value) {
+ // TODO ????
+ }
- public double update_nodes(double cur_alpha, double the_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes) {
- if (cur_alpha == the_alpha) {
+ public double check_node(double cur_alpha,TreeNode cur_node, ArrayList<TreeNode> nodes) {
+ if (Util.epsilon(cur_alpha - an_alpha)) {
for(TreeNode parent:nodes) {
if (isChildOf(cur_node, parent))
return cur_alpha;// it is not added
@@ -443,44 +369,64 @@
return cur_alpha;
}
-// public double getAlpha() {
-// return an_alpha;
-// }
+ public double getAlpha() {
+ return an_alpha;
+ }
}
public class MinAlphaProc implements AlphaSelectionProc{
-// private double best_alpha;
-// public MinAlphaProc(double value) {
-// best_alpha = value;
-// }
+ private double sum_min_alpha, init_min;
+ private int num_minimum;
+ public MinAlphaProc(double value) {
+ init_min = value;
+ sum_min_alpha = 0;
+ num_minimum = 0;
+ }
- public double update_nodes(double cur_alpha, double the_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes) {
- if (cur_alpha == the_alpha) {
+ public void init_proc(double value) {
+ init_min = value;
+ sum_min_alpha = 0;
+ num_minimum = 0;
+ }
+
+ public double check_node(double cur_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes) {
+ double average_of_min_alphas = getAlpha();
+ if (slog.debug() !=null)
+ slog.debug().log(":search_alphas:alpha "+ cur_alpha+ "/"+average_of_min_alphas+ " diff "+(cur_alpha - average_of_min_alphas)+"\n");
+
+ if (Util.epsilon(cur_alpha - average_of_min_alphas)) {
+ // check if the cur_node is a child of any node that has been added to the list before.
for(TreeNode parent:nodes) {
if (isChildOf(cur_node, parent))
- return cur_alpha;// it is not added
+ return cur_alpha;// if it is the case do not add the node
}
- // add this one to the set
+ // else add this one to the set
nodes.add(cur_node);
+ sum_min_alpha += cur_alpha;
+ num_minimum ++;
return cur_alpha;
- } else if (cur_alpha < the_alpha) {
+ } else if (cur_alpha < average_of_min_alphas) {
nodes.clear(); // can not put a new 'cause then it does not update the global one = new ArrayList<TreeNode>();
// remove the ones you found and replace with that one
//tree_sequence.get(dt_id).put(my_node), alpha
-
+ num_minimum = 1;
+ sum_min_alpha = cur_alpha;
nodes.add(cur_node);
return cur_alpha;
} else {
}
- return the_alpha;
+ return sum_min_alpha/num_minimum;
}
-// public double getAlpha() {
-// return best_alpha;
-// }
+ public double getAlpha() {
+ if (num_minimum == 0)
+ return init_min;
+ else
+ return sum_min_alpha/num_minimum;
+ }
}
public class NodeUpdate{
@@ -531,7 +477,7 @@
private int iteration_id;
private int num_terminal_nodes;
- private double cross_validated_cost;
+ private double test_cost;
private double resubstitution_cost;
private double cost_complexity;
private double alpha;
@@ -542,7 +488,7 @@
// to set an node update with the worst cross validated error
public TreeStats(double error) {
iteration_id = 0;
- cross_validated_cost = error;
+ test_cost = error;
}
public void iteration_id(int i) {
@@ -556,12 +502,12 @@
this.num_terminal_nodes = num_terminal_nodes;
}
- public double getCross_validated_cost() {
- return cross_validated_cost;
+ public double getTest_cost() {
+ return test_cost;
}
- public void setCross_validated_cost(double cross_validated_cost) {
- this.cross_validated_cost = cross_validated_cost;
+ public void setTest_cost(double valid_cost) {
+ this.test_cost = valid_cost;
}
public double getResubstitution_cost() {
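As a numeric illustration of the alpha computed in search_alphas (illustrative numbers only, not taken from this commit): the extra misclassifications caused by pruning, divided by the training size times the number of leaves removed.

public class AlphaSketch {
    public static void main(String[] args) {
        int k = 3;               // extra misclassifications if the candidate subtree is pruned
        int training_size = 300; // procedure.getTrainingDataSize(focus_tree.getId())
        int num_leaves = 4;      // leaves under the candidate node
        double alpha = ((double) k) / ((double) training_size * (num_leaves - 1)); // ~0.0033
        System.out.println(alpha);
        // MinAlphaProc keeps the running average of all alphas within Util.epsilon of the
        // current minimum, so equally cheap subtrees are pruned together in one iteration.
    }
}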
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/TreeNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/TreeNode.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/TreeNode.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -27,6 +27,8 @@
private int label_size;
private int leaves;
+ private int depth;
+
public TreeNode(Domain domain) {
this.father = null;
this.domain = domain;
@@ -34,6 +36,14 @@
}
+ public void setDepth(int d) {
+ depth = d;
+ }
+
+ public int getDepth() {
+ return depth;
+ }
+
public double getRank() {
return rank;
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -52,8 +52,11 @@
}
int N = class_instances.getSize();
- int NUM_DATA = (int)(TREE_SIZE_RATIO * N);
- _trainer.setDataSizePerTree(NUM_DATA);
+ int NUM_DATA = (int)(TREE_SIZE_RATIO * N); // TREE_SIZE_RATIO = 1.0, all data is used to train the trees again again
+ _trainer.setTrainingDataSizePerTree(NUM_DATA);
+
+ /* all data fed to each tree, the same data?? */
+ _trainer.setTrainingDataSize(NUM_DATA); // TODO????
forest = new ArrayList<DecisionTree> (FOREST_SIZE);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -48,9 +48,13 @@
}
int N = class_instances.getSize();
+ //_trainer.setTrainingDataSize(N); not only N data is fed.
+
int K = _trainer.getTargetDomain().getCategoryCount();
int M = (int)(TREE_SIZE_RATIO * N);
- _trainer.setDataSizePerTree(M);
+ _trainer.setTrainingDataSizePerTree(M);
+ /* M data fed to each tree, there are FOREST_SIZE trees*/
+ _trainer.setTrainingDataSize(M * FOREST_SIZE);
forest = new ArrayList<DecisionTree> (FOREST_SIZE);
@@ -61,7 +65,6 @@
for (int index_j=0; index_j<K; index_j++) {
Instance inst_i = class_instances.getInstance(index_i);
-
Object instance_target = inst_i.getAttrValue(_trainer.getTargetDomain().getFReferenceName());
Object instance_target_category = _trainer.getTargetDomain().getCategoryOf(instance_target);
Object target_category= _trainer.getTargetDomain().getCategory(index_j);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -1,5 +1,6 @@
package org.drools.learner.builder;
+import java.util.ArrayList;
import java.util.Hashtable;
import java.util.List;
@@ -11,21 +12,31 @@
import org.drools.learner.eval.Heuristic;
import org.drools.learner.eval.InformationContainer;
import org.drools.learner.eval.InstDistribution;
+import org.drools.learner.eval.StoppingCriterion;
import org.drools.learner.tools.FeatureNotSupported;
import org.drools.learner.tools.Util;
public class C45Learner extends Learner{
private AttributeChooser chooser;
+ private ArrayList<StoppingCriterion> criteria;
public C45Learner(Heuristic hf) {
super();
super.setDomainAlgo(DomainAlgo.QUANTITATIVE);
chooser = new AttributeChooser(hf);
+ criteria = null;
}
- protected TreeNode train(DecisionTree dt, InstDistribution data_stats) {//List<Instance> data) {
+ public C45Learner(Heuristic hf, ArrayList<StoppingCriterion> _criteria) {
+ super();
+ super.setDomainAlgo(DomainAlgo.QUANTITATIVE);
+ chooser = new AttributeChooser(hf);
+ criteria = _criteria;
+ }
+
+ protected TreeNode train(DecisionTree dt, InstDistribution data_stats, int depth) {//List<Instance> data) {
if (data_stats.getSum() == 0) {
throw new RuntimeException("Nothing to classify, factlist is empty");
@@ -42,7 +53,7 @@
LeafNode classifiedNode = new LeafNode(dt.getTargetDomain() /* target domain*/,
data_stats.get_winner_class() /*winner target category*/);
classifiedNode.setRank( (double)data_stats.getSum()/
- (double)this.getDataSize()/* total size of data fed to dt*/);
+ (double)this.getTrainingDataSize()/* total size of data fed to dt*/);
classifiedNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
classifiedNode.setNumClassification(data_stats.getSum()); //num of classified instances at the leaf node
@@ -58,7 +69,7 @@
LeafNode noAttributeLeftNode = new LeafNode(dt.getTargetDomain() /* target domain*/,
winner);
noAttributeLeftNode.setRank((double)data_stats.getVoteFor(winner)/
- (double)this.getDataSize() /* total size of data fed to dt*/);
+ (double)this.getTrainingDataSize() /* total size of data fed to dt*/);
noAttributeLeftNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
noAttributeLeftNode.setNumClassification(data_stats.getVoteFor(winner)); //num of classified instances at the leaf node
//noAttributeLeftNode.setInfoMea(best_attr_eval.attribute_eval);
@@ -66,24 +77,42 @@
/* we need to know how many guys cannot be classified and who these guys are */
data_stats.missClassifiedInstances(missclassified_data);
- dt.setTrainingError(dt.getTrainingError() + data_stats.getSum()/dt.FACTS_READ);
+ dt.setTrainingError(dt.getTrainingError() + data_stats.getSum()/getTrainingDataSize());
return noAttributeLeftNode;
}
-
InformationContainer best_attr_eval = new InformationContainer();
-
+ best_attr_eval.setStats(data_stats);
+ best_attr_eval.setDepth(depth);
+ best_attr_eval.setTotalNumData(getTrainingDataSizePerTree());
+
/* choosing the best attribute in order to branch at the current node*/
chooser.chooseAttribute(best_attr_eval, data_stats, attribute_domains);
+
+ if (criteria != null & criteria.size()>0) {
+ for (StoppingCriterion sc: criteria)
+ if (sc.stop(best_attr_eval)) {
+ Object winner = data_stats.get_winner_class();
+ LeafNode majorityNode = new LeafNode(dt.getTargetDomain(), winner);
+ majorityNode.setRank((double)data_stats.getVoteFor(winner)/
+ (double)this.getTrainingDataSize() /* total size of data fed to trainer*/);
+ majorityNode.setNumMatch(data_stats.getSum());
+ majorityNode.setNumClassification(data_stats.getVoteFor(winner));
+
+ /* we need to know how many guys cannot be classified and who these guys are */
+ data_stats.missClassifiedInstances(missclassified_data);
+ dt.setTrainingError(dt.getTrainingError() + (data_stats.getSum()-data_stats.getVoteFor(winner))/getTrainingDataSize());
+ return majorityNode;
+ }
+ }
Domain node_domain = best_attr_eval.domain;
-
if (slog.debug() != null)
slog.debug().log("\n"+Util.ntimes("*", 20)+" 1st best attr: "+ node_domain);
TreeNode currentNode = new TreeNode(node_domain);
currentNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
currentNode.setRank((double)data_stats.getSum()/
- (double)this.getDataSize() /* total size of data fed to dt*/);
+ (double)this.getTrainingDataSize() /* total size of data fed to trainer*/);
currentNode.setInfoMea(best_attr_eval.attribute_eval);
//what the highest represented class is and what proportion of items at that node actually are that class
currentNode.setLabel(data_stats.get_winner_class());
@@ -107,6 +136,7 @@
/* list of domains except the choosen one (&target domain)*/
DecisionTree child_dt = new DecisionTree(dt, node_domain);
+ child_dt.FACTS_READ = dt.FACTS_READ;
if (filtered_stats == null || filtered_stats.get(category) == null || filtered_stats.get(category).getSum() ==0) {
/* majority !!!! */
@@ -120,7 +150,7 @@
majorityNode.setFather(currentNode);
currentNode.putNode(category, majorityNode);
} else {
- TreeNode newNode = train(child_dt, filtered_stats.get(category));//, attributeNames_copy
+ TreeNode newNode = train(child_dt, filtered_stats.get(category), depth+1);//, attributeNames_copy
newNode.setFather(currentNode);
currentNode.putNode(category, newNode);
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -1,5 +1,7 @@
package org.drools.learner.builder;
+import java.util.ArrayList;
+
import org.drools.WorkingMemory;
import org.drools.learner.DecisionTree;
import org.drools.learner.DecisionTreePruner;
@@ -9,8 +11,10 @@
import org.drools.learner.builder.DecisionTreeBuilder.TreeAlgo;
import org.drools.learner.eval.CrossValidation;
import org.drools.learner.eval.Entropy;
+import org.drools.learner.eval.EstimatedNodeSize;
import org.drools.learner.eval.GainRatio;
import org.drools.learner.eval.Heuristic;
+import org.drools.learner.eval.StoppingCriterion;
import org.drools.learner.tools.FeatureNotSupported;
import org.drools.learner.tools.Util;
@@ -201,6 +205,92 @@
return learner.getTree();
}
+
+ public static DecisionTree createSingleC45E_StoppingCriteria(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ return createSingleC45_Stop(wm, obj_class, new Entropy());
+ }
+
+// public static DecisionTree createSingleC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+// return createSingleC45(wm, obj_class, new GainRatio());
+// }
+
+ protected static DecisionTree createSingleC45_Stop(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
+ DataType data = Learner.DEFAULT_DATA;
+ ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+ stopping_criteria.add(new EstimatedNodeSize(0.5));
+ C45Learner learner = new C45Learner(h, stopping_criteria);
+ SingleTreeBuilder single_builder = new SingleTreeBuilder();
+
+// String algo_suffices = org.drools.learner.deprecated.DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
+// String executionSignature = org.drools.learner.deprecated.DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
+ String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
+ String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
+
+ /* create the memory */
+ Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
+ single_builder.build(mem, learner);//obj_class, target_attr, working_attr
+
+ SingleTreeTester tester = new SingleTreeTester(learner.getTree());
+ tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature);
+ //Tester.test(c45, mem.getClassInstances());
+
+ learner.getTree().setSignature(executionSignature);
+ return learner.getTree();
+ }
+
+
+ public static DecisionTree createSinglePrunnedStopC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ return createSinglePrunnedStopC45(wm, obj_class, new Entropy());
+ }
+ public static DecisionTree createSinglePrunnedStopC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ return createSinglePrunnedStopC45(wm, obj_class, new GainRatio());
+ }
+
+ protected static DecisionTree createSinglePrunnedStopC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
+ DataType data = Learner.DEFAULT_DATA;
+ ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+ stopping_criteria.add(new EstimatedNodeSize(0.05));
+ C45Learner learner = new C45Learner(h, stopping_criteria);
+
+ SingleTreeBuilder single_builder = new SingleTreeBuilder();
+
+ String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
+ String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
+
+ /* create the memory */
+ Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
+ single_builder.build(mem, learner);//obj_class, target_attr, working_attr
+
+ CrossValidation validater = new CrossValidation(10, mem.getClassInstances());
+ validater.validate(learner);
+
+ DecisionTreePruner pruner = new DecisionTreePruner(validater);
+ pruner.prun_to_estimate();
+
+ // you should be able to get the pruned tree
+ // prun.getMinimumCostTree()
+ // prun.getOptimumCostTree()
+
+ // test the tree
+ SingleTreeTester tester = new SingleTreeTester(learner.getTree());
+ tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature);
+
+ /* Once Talpha is found the tree that is finally suggested for use is that
+ * which minimises the cost-complexity using and all the data use the pruner to prun the tree
+ */
+ pruner.prun_tree(learner.getTree());
+
+
+ // test the tree again
+
+
+ //Tester.test(c45, mem.getClassInstances());
+
+ learner.getTree().setSignature(executionSignature);
+ return learner.getTree();
+ }
+
+
public static String getSignature(Class<? extends Object> obj_class, String fileName, String suffices) {
//String fileName = (dataFile == null || dataFile == "") ? this.getRuleClass().getSimpleName().toLowerCase(): dataFile;
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -43,8 +43,12 @@
}
int N = class_instances.getSize();
+ // _trainer.setTrainingDataSize(N); => wrong
int tree_capacity = (int)(TREE_SIZE_RATIO * N);
- _trainer.setDataSizePerTree(tree_capacity);
+ _trainer.setTrainingDataSizePerTree(tree_capacity);
+
+ /* tree_capacity number of data fed to each tree, there are FOREST_SIZE trees*/
+ _trainer.setTrainingDataSize(tree_capacity * FOREST_SIZE);
forest = new ArrayList<DecisionTree> (FOREST_SIZE);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -22,7 +22,7 @@
chooser = new AttributeChooser(hf);
}
- protected TreeNode train(DecisionTree dt, InstDistribution data_stats) {//List<Instance> data) {
+ protected TreeNode train(DecisionTree dt, InstDistribution data_stats, int depth) {//List<Instance> data) {
if (data_stats.getSum() == 0) {
throw new RuntimeException("Nothing to classify, factlist is empty");
@@ -39,7 +39,7 @@
LeafNode classifiedNode = new LeafNode(dt.getTargetDomain() /* target domain*/,
data_stats.get_winner_class() /*winner target category*/);
classifiedNode.setRank( (double)data_stats.getSum()/
- (double)this.getDataSize()/* total size of data fed to dt*/);
+ (double)this.getTrainingDataSize()/* total size of data fed to dt*/);
classifiedNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
classifiedNode.setNumClassification(data_stats.getSum()); //num of classified instances at the leaf node
return classifiedNode;
@@ -53,7 +53,7 @@
LeafNode noAttributeLeftNode = new LeafNode(dt.getTargetDomain() /* target domain*/,
winner);
noAttributeLeftNode.setRank((double)data_stats.getVoteFor(winner)/
- (double)this.getDataSize() /* total size of data fed to dt*/);
+ (double)this.getTrainingDataSize() /* total size of data fed to dt*/);
noAttributeLeftNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
noAttributeLeftNode.setNumClassification(data_stats.getVoteFor(winner)); //num of classified instances at the leaf node
return noAttributeLeftNode;
@@ -89,7 +89,7 @@
majorityNode.setNumClassification(0);
currentNode.putNode(category, majorityNode);
} else {
- TreeNode newNode = train(child_dt, filtered_stats.get(category));//, attributeNames_copy
+ TreeNode newNode = train(child_dt, filtered_stats.get(category), depth+1);//, attributeNames_copy
currentNode.putNode(category, newNode);
}
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -22,7 +22,7 @@
public static enum DataType {PRIMITIVE, STRUCTURED, COLLECTION}
public static DataType DEFAULT_DATA = DataType.PRIMITIVE;
- private int data_size;
+ private int data_size, data_size_per_tree;
private DecisionTree best_tree;
private InstanceList input_data;
protected HashSet<Instance> missclassified_data;
@@ -31,10 +31,11 @@
private DomainAlgo algorithm;
- protected abstract TreeNode train(DecisionTree dt, InstDistribution data_stats);
+ protected abstract TreeNode train(DecisionTree dt, InstDistribution data_stats, int depth);
public Learner() {
this.data_size = 0;
+ this.data_size_per_tree = 0;
}
@@ -47,9 +48,11 @@
InstDistribution stats_by_class = new InstDistribution(dt.getTargetDomain());
stats_by_class.calculateDistribution(working_instances.getInstances());
+
+
dt.FACTS_READ += working_instances.getSize();
- TreeNode root = train(dt, stats_by_class);
+ TreeNode root = train(dt, stats_by_class, 0);
dt.setRoot(root);
//flog.debug("Result tree\n" + dt);
return dt;
@@ -76,7 +79,7 @@
stats_by_class.calculateDistribution(working_instances.getInstances());
dt.FACTS_READ += working_instances.getSize();
- TreeNode root = train(dt, stats_by_class);
+ TreeNode root = train(dt, stats_by_class, 0);
dt.setRoot(root);
//flog.debug("Result tree\n" + dt);
}
@@ -84,13 +87,20 @@
}
- public void setDataSizePerTree(int num) {
- this.data_size = num;
+ public void setTrainingDataSizePerTree(int num) {
+ this.data_size_per_tree = num;
missclassified_data = new HashSet<Instance>();
}
- public int getDataSize() {
+ public int getTrainingDataSizePerTree() {
+ return this.data_size_per_tree;
+ }
+
+ public void setTrainingDataSize(int num) {
+ this.data_size = num;
+ }
+ public int getTrainingDataSize() {
return this.data_size;
}
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -39,8 +39,8 @@
System.exit(0);
// TODO put the feature not supported exception || implement it
}
-
- _trainer.setDataSizePerTree(class_instances.getSize());
+ _trainer.setTrainingDataSize(class_instances.getSize());
+ _trainer.setTrainingDataSizePerTree(class_instances.getSize());
one_tree = _trainer.train_tree(class_instances);
_trainer.setBestTree(one_tree);
}
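A condensed sketch of the training-size accounting the builders now perform through the Learner setters changed above (FOREST_SIZE and tree_capacity stand in for the builder fields shown in these diffs):

import org.drools.learner.builder.Learner;

public class TrainingSizeSketch {
    static void setSizes(Learner trainer, int N, int tree_capacity, int FOREST_SIZE, boolean forest) {
        if (!forest) {
            // SingleTreeBuilder: the one tree sees the whole training set
            trainer.setTrainingDataSizePerTree(N);
            trainer.setTrainingDataSize(N);
        } else {
            // ForestBuilder / AdaBoostKBuilder: each of FOREST_SIZE trees sees tree_capacity instances
            trainer.setTrainingDataSizePerTree(tree_capacity);
            trainer.setTrainingDataSize(tree_capacity * FOREST_SIZE);
        }
    }
}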
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -106,10 +106,11 @@
dt.calc_num_node_leaves(dt.getRoot());
if (slog.error() !=null)
- slog.error().log("The estimate of : "+(i-1)+" training=" +dt.getTrainingError() +" valid=" + error +" num_leaves=" + dt.getRoot().getNumLeaves()+"\n");
+ slog.error().log("The estimate of : "+(i-1)+" training=" +dt.getTrainingError() +" valid=" + dt.getValidationError() +" num_leaves=" + dt.getRoot().getNumLeaves()+"\n");
+ /* moving averages */
validation_error_estimate += ((double)error/(double) fold_size)/(double)k_fold;
- training_error_estimate += ((double)dt.getTrainingError()/(double)(num_instances-fold_size))/(double)k_fold;
+ training_error_estimate += ((double)dt.getTrainingError())/(double)k_fold;//((double)dt.getTrainingError()/(double)(num_instances-fold_size))/(double)k_fold;
num_leaves_estimate += (double)dt.getRoot().getNumLeaves()/(double)k_fold;
@@ -190,6 +191,10 @@
//
// }
+ public int getTrainingDataSize(int i) {
+ return num_instances-getFoldSize(i);
+ }
+
public double getAlphaEstimate() {
return alpha_estimate;
}
@@ -197,7 +202,7 @@
int excess = num_instances % k_fold;
return (int) num_instances/k_fold + (i < excess? 1:0);
}
- public double getValidationErrorEstimate() {
+ public double getErrorEstimate() {
return validation_error_estimate;
}
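A small arithmetic illustration of the new getTrainingDataSize accessor (illustrative numbers only; the real getFoldSize also distributes the remainder across folds):

public class FoldSizeSketch {
    public static void main(String[] args) {
        int num_instances = 300, k_fold = 10;            // 10-fold cross-validation over 300 instances
        int fold_size = num_instances / k_fold;          // 30 instances held out as the validation fold
        int training_size = num_instances - fold_size;   // 270 instances grow each estimator
        System.out.println(fold_size + " held out, " + training_size + " used for training");
    }
}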
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/EstimatedNodeSize.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/EstimatedNodeSize.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/EstimatedNodeSize.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -0,0 +1,29 @@
+package org.drools.learner.eval;
+
+public class EstimatedNodeSize implements StoppingCriterion {
+
+ private double outlier_percentage;
+ private int estimated_sum_branch;
+ private int num_times_branched;
+
+
+ public EstimatedNodeSize(double o_p) {
+ outlier_percentage = o_p;
+ num_times_branched = 0;
+ }
+
+
+ public boolean stop(InformationContainer best_attr_eval) {
+ int d = best_attr_eval.getDepth();
+ estimated_sum_branch += best_attr_eval.domain.getCategoryCount();
+ num_times_branched ++;
+ double estimated_branch = (double)estimated_sum_branch/(double)num_times_branched;
+ // N/(b^d)
+ double estimated_size = best_attr_eval.getTotalNumData()/Math.pow(estimated_branch, d);
+ System.out.println("EstimatedNodeSize:stop: " +best_attr_eval.getNumData() + " <= " + ( Math.ceil(estimated_size*outlier_percentage)-1) +" / "+estimated_size);
+ if (best_attr_eval.getNumData() <= Math.ceil(estimated_size*outlier_percentage)-1)
+ return true;
+ else
+ return false;
+ }
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/EstimatedNodeSize.java
___________________________________________________________________
Name: svn:eol-style
+ native
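The criterion above compares a node's instance count against the size it would have if the data split evenly, N/(b^d) with b the running average branching factor; an arithmetic sketch with made-up numbers:

public class EstimatedNodeSizeSketch {
    public static void main(String[] args) {
        double N = 1000.0;       // instances fed to the tree (getTotalNumData)
        double b = 3.0;          // running average number of branches per split
        int d = 4;               // depth of the node being considered
        double estimated_size = N / Math.pow(b, d);              // ~12.3 expected instances at depth 4
        double threshold = Math.ceil(estimated_size * 0.5) - 1;  // outlier_percentage = 0.5 -> 6
        System.out.println("stop if the node covers <= " + threshold + " instances");
    }
}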
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -10,7 +10,8 @@
public int getEstimatorSize();
public ArrayList<DecisionTree> getEstimators();
public ArrayList<InstanceList> getFold(int id);
+ public int getTrainingDataSize(int i);
- public double getValidationErrorEstimate();
+ public double getErrorEstimate();
public double getAlphaEstimate();
}
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ImpurityDecrease.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ImpurityDecrease.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ImpurityDecrease.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -0,0 +1,18 @@
+package org.drools.learner.eval;
+
+
+public class ImpurityDecrease implements StoppingCriterion {
+
+ private double beta = 0.1;
+
+ public ImpurityDecrease(double _beta) {
+ beta = _beta;
+ }
+ public boolean stop(InformationContainer best_attr_eval) {
+ if (best_attr_eval.attribute_eval < beta)
+ return true;
+ else
+ return false;
+ }
+
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ImpurityDecrease.java
___________________________________________________________________
Name: svn:eol-style
+ native
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -12,13 +12,44 @@
//public double gain_ratio;
public ArrayList<Instance> sorted_data;
+ private InstDistribution stats;
+ private int depth;
+ private int total_num_data;
+
public InformationContainer() {
domain = null;
attribute_eval = 0.0;
sorted_data = null;
+ depth = 0;
+ stats = null;
+ total_num_data = 0;
}
+
+ public void setStats(InstDistribution data_stats) {
+ stats = data_stats;
+ }
+
+ public void setDepth(int _depth) {
+ depth = _depth;
+ }
+
+ public int getDepth() {
+ return depth;
+ }
+ public double getNumData() {
+ return stats.getSum();
+ }
+
+ // total num of data fed to per tree
+ public void setTotalNumData(int num) {
+ total_num_data = num;
+ }
+ public int getTotalNumData() {
+ return total_num_data ;
+ }
+
// public InformationContainer(Domain _domain, double _attribute_eval, double _gain_ratio) {
// this.domain = _domain;
// this.attribute_eval = _attribute_eval;
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -99,8 +99,7 @@
}
- public Hashtable<Object, InstDistribution> splitFromCategorical(
- Domain splitDomain, Hashtable<Object, InstDistribution> instLists) {
+ public Hashtable<Object, InstDistribution> splitFromCategorical(Domain splitDomain, Hashtable<Object, InstDistribution> instLists) {
if (instLists == null)
instLists = this.instantiateLists(splitDomain);
@@ -122,8 +121,7 @@
return instLists;
}
- private void splitFromQuantitative(ArrayList<Instance> data,
- QuantitativeDomain attributeDomain, Hashtable<Object, InstDistribution> instLists) {
+ private void splitFromQuantitative(ArrayList<Instance> data, QuantitativeDomain attributeDomain, Hashtable<Object, InstDistribution> instLists) {
String attributeName = attributeDomain.getFName();
String targetName = super.getClassDomain().getFReferenceName();
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/MaximumDepth.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/MaximumDepth.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/MaximumDepth.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -0,0 +1,17 @@
+package org.drools.learner.eval;
+
+
+public class MaximumDepth implements StoppingCriterion {
+
+ private int limit_depth;
+ public MaximumDepth(int _depth) {
+ limit_depth = _depth;
+ }
+ public boolean stop(InformationContainer best_attr_eval) {
+ if (best_attr_eval.getDepth() <= limit_depth)
+ return false;
+ else
+ return true;
+ }
+
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/MaximumDepth.java
___________________________________________________________________
Name: svn:eol-style
+ native
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/StoppingCriterion.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/StoppingCriterion.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/StoppingCriterion.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -0,0 +1,9 @@
+package org.drools.learner.eval;
+
+import org.drools.learner.DecisionTree;
+
+public interface StoppingCriterion {
+
+ public boolean stop(InformationContainer best_attr_eval);
+
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/StoppingCriterion.java
___________________________________________________________________
Name: svn:eol-style
+ native
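Since the interface only requires stop(InformationContainer), further criteria can be plugged in; a hypothetical example (not part of this commit) that stops once a node covers fewer than a fixed number of instances:

package org.drools.learner.eval;

public class MinimumNodeSize implements StoppingCriterion {

    private double min_num_instances;

    public MinimumNodeSize(double min) {
        min_num_instances = min;
    }

    // stop growing when the node's instance count drops below the fixed limit
    public boolean stop(InformationContainer best_attr_eval) {
        return best_attr_eval.getNumData() < min_num_instances;
    }
}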
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -66,6 +66,10 @@
return (double)x/(double)y;
}
+ public static boolean epsilon(double d) {
+ return Math.abs(d) <= 0.0001;
+ }
+
/* TODO make this all_fields arraylist as hashmap */
public static void getSuperFields(Class<?> clazz, ArrayList<Field> all_fields) {
if (clazz == Object.class)
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -37,7 +37,7 @@
}
// instantiate a learner for a specific object class and pass session to train
- DecisionTree decision_tree; int ALGO = 400;
+ DecisionTree decision_tree; int ALGO = 600;
/*
* Single 1xx, Bag 2xx, Boost 3xx
* ID3 x1x, C45 x2x
@@ -74,6 +74,12 @@
case 400:
decision_tree = DecisionTreeFactory.createSinglePrunnedC45E(session, obj_class);
break;
+ case 500:
+ decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
+ break;
+ case 600:
+ decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+ break;
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -40,7 +40,7 @@
session.insert(r);
}
- DecisionTree decision_tree; int ALGO = 400;
+ DecisionTree decision_tree; int ALGO = 600;
/*
* Single 1xx, Bag 2xx, Boost 3xx
* ID3 x1x, C45 x2x
@@ -77,6 +77,12 @@
case 400:
decision_tree = DecisionTreeFactory.createSinglePrunnedC45E(session, obj_class);
break;
+ case 500:
+ decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
+ break;
+ case 600:
+ decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+ break;
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java 2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java 2008-08-01 13:12:34 UTC (rev 21321)
@@ -36,7 +36,7 @@
}
// instantiate a learner for a specific object class and pass session to train
- DecisionTree decision_tree; int ALGO = 221;
+ DecisionTree decision_tree; int ALGO = 600;
/*
* Single 1xx, Bag 2xx, Boost 3xx
* ID3 x1x, C45 x2x
@@ -67,6 +67,12 @@
case 322:
decision_tree = DecisionTreeFactory.createBoostedC45G(session, obj_class);
break;
+ case 500:
+ decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
+ break;
+ case 600:
+ decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+ break;
// case 3:
// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
// break;