[jboss-svn-commits] JBL Code SVN: r21422 - in labs/jbossrules/contrib/machinelearning/5.0: drools-core/src/main/java/org/drools/learner/builder and 4 other directories.
jboss-svn-commits at lists.jboss.org
Sun Aug 10 21:52:30 EDT 2008
Author: gizil
Date: 2008-08-10 21:52:30 -0400 (Sun, 10 Aug 2008)
New Revision: 21422
Added:
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ErrorEstimate.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/EstimatedNodeSize.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/ImpurityDecrease.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/StoppingCriterion.java
Modified:
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
Log:
pruner bug fixes
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -198,7 +198,7 @@
TreeNode child = my_node.getChild(child_key);
if (child instanceof LeafNode) {
- terminals.add((LeafNode)my_node);
+ terminals.add((LeafNode)child);
if (!anchestor_added) {
num_nonterminal_nodes ++; // TODO does this really work?
anchestors.add(my_node);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -1,10 +1,9 @@
package org.drools.learner;
import java.util.ArrayList;
-import java.util.Collections;
import org.drools.learner.builder.SingleTreeTester;
-import org.drools.learner.eval.Estimator;
+import org.drools.learner.eval.ErrorEstimate;
import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
@@ -16,15 +15,16 @@
private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(DecisionTreePruner.class, SimpleLogger.DEFAULT_LEVEL);
private static SimpleLogger slog = LoggerFactory.getSysOutLogger(DecisionTreePruner.class, SimpleLogger.DEBUG);
- private Estimator procedure;
+ private ErrorEstimate procedure;
private TreeStats best_stats;
private int num_trees_to_grow;
private double INIT_ALPHA = 0.5d;
+ private static final double EPSILON = 0.0d;//0.0000000001;
- public DecisionTreePruner(Estimator proc) {
+ public DecisionTreePruner(ErrorEstimate proc) {
procedure = proc;
num_trees_to_grow = procedure.getEstimatorSize();
//updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
@@ -43,15 +43,20 @@
* which is done at the second stage.
*/
double value_to_select = procedure.getErrorEstimate();
-// private NodeUpdate best_update;
- for (DecisionTree dt: procedure.getEstimators()) {
+
+ for (int dt_i = 0; dt_i<procedure.getEstimatorSize(); dt_i++) {
+ DecisionTree dt= procedure.getEstimator(dt_i);
+
+
// dt.getId()
- //dt.calc_num_node_leaves(dt.getRoot()); // this is done in the estimator
+ // dt.calc_num_node_leaves(dt.getRoot()); // this is done in the estimator
- MinAlphaProc alpha_proc = new MinAlphaProc(INIT_ALPHA);
+ double epsilon = EPSILON * numExtraMisClassIfPrun(dt.getRoot());
+ MinAlphaProc alpha_proc = new MinAlphaProc(INIT_ALPHA, epsilon);
TreeSequenceProc search = new TreeSequenceProc(dt, alpha_proc);//INIT_ALPHA
- search.iterate_trees(0);
+ search.init_tree(); // alpha_1 = 0.0
+ search.iterate_trees(1);
//updates.add(tree_sequence);
updates.add(search.getTreeSequence());
@@ -60,25 +65,25 @@
// sort the found candidates
//Collections.sort(updates.get(dt.getId()), arg1)
-
- int id =0;
+ int sid = 0, best_id = 0;
+ double best_error = 1.0;
System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
- for (NodeUpdate nu: updates.get(dt.getId()) ){
- //TODO What to print here?
- //System.out.println(id +"\t"+ nu.getNum_terminal_nodes()+"\t"+nu.getCross_validated_cost()+"\t"+nu.getResubstitution_cost()+"\t"+nu.getAlpha()+"\n");
- id++;
+ ArrayList<TreeStats> trees = sequence_stats.get(dt.getId());
+ for (TreeStats st: trees ){
+ if (st.getCostEstimation() <= best_error) {
+ best_error = st.getCostEstimation();
+ best_id = sid;
+ }
+ System.out.println(sid+ "\t " +st.getNum_terminal_nodes()+ "\t "+ st.getCostEstimation() + "\t "+ st.getResubstitution_cost() + "\t "+ st.getAlpha() );
+ sid++;
}
-
- int sid =0;
- System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
- for (TreeStats st: sequence_stats.get(dt.getId()) ){
- //System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
- System.out.println(sid+ "" +st.getAlpha() +" "+ st.getTest_cost());
- sid++;
-
- }
- int x =0;
+ System.out.println("BEST "+best_id+ "\t " +trees.get(best_id).getNum_terminal_nodes()+ "\t "+ trees.get(best_id).getCostEstimation() + "\t "+ trees.get(best_id).getResubstitution_cost() + "\t "+ trees.get(best_id).getAlpha() );
+
+ int N = procedure.getTestDataSize(dt_i);
+ double standart_error_estimate = Math.sqrt((trees.get(best_id).getCostEstimation() * (1- trees.get(best_id).getCostEstimation())/ (double)N));
}
+
+
}
@@ -94,7 +99,8 @@
}
public void prun_tree(DecisionTree tree) {
- TreeSequenceProc search = new TreeSequenceProc(tree, new AnAlphaProc(best_stats.getAlpha()));
+ double epsilon = 0.0000001 * numExtraMisClassIfPrun(tree.getRoot());
+ TreeSequenceProc search = new TreeSequenceProc(tree, new AnAlphaProc(best_stats.getAlpha(), epsilon));
search.iterate_trees(0);
//search.getTreeSequence()// to go back
@@ -108,23 +114,30 @@
}
+ // returns the node missclassification cost
+ private int R(TreeNode t) {
+ return (int) t.getNumMatch() - t.getNumLabeled();
+ }
private int numExtraMisClassIfPrun(TreeNode my_node) {
- int num_misclassified = (int) my_node.getNumMatch() - my_node.getNumLabeled(); // needs to be cast because of
+ int num_misclassified = R(my_node); // needs to be cast because of
+
if (slog.debug() !=null)
slog.debug().log(":numExtraMisClassIfPrun:num_misclassified "+ num_misclassified);
-
-
int kids_misclassified = 0;
for(Object key: my_node.getChildrenKeys()) {
TreeNode child = my_node.getChild(key);
- kids_misclassified += child.getMissClassified();//(int) child.getNumMatch() - child.getNumLabeled();
+ kids_misclassified += child.getMissClassified();
if (slog.debug() !=null)
slog.debug().log(" kid="+ kids_misclassified );
}
if (slog.debug() !=null)
slog.debug().log("\n");
+ if (num_misclassified < kids_misclassified) {
+ System.out.println("Problem ++++++");
+ System.exit(0);
+ }
return num_misclassified - kids_misclassified;
}
@@ -132,6 +145,7 @@
public class TreeSequenceProc {
private static final double MAX_ERROR_RATIO = 0.99;
+
private DecisionTree focus_tree;
//private double the_alpha;
private AlphaSelectionProc alpha_proc;
@@ -141,7 +155,8 @@
private TreeStats best_tree_stats;
public TreeSequenceProc(DecisionTree dt, AlphaSelectionProc cond) { //, double init_alpha
focus_tree = dt;
- //the_alpha = init_alpha;
+
+
alpha_proc = cond;
tree_sequence = new ArrayList<NodeUpdate>();
tree_sequence_stats = new ArrayList<TreeStats>();
@@ -169,6 +184,38 @@
return tree_sequence;
}
+ private void init_tree() {
+ // initialize the tree to be prunned
+ // T_1 is the smallest subtree of T_max satisfying R(T_1) = R(T_max)
+
+ ArrayList<TreeNode> last_nonterminals = focus_tree.getAnchestor_of_Leaves(focus_tree.getRoot());
+
+ // R(t) <= sum R(t_children) by the proposition 2.14, R(t) node miss-classification cost
+ // if R(t) = sum R(t_children) then prune off t_children
+
+ boolean tree_changed = false; // (k)
+ for (TreeNode t: last_nonterminals) {
+// int R_children = 0;
+// for (Object child_key: t.getChildrenKeys()) {
+// TreeNode child = t.getChild(child_key);
+// R_children += child.getMissClassified();
+// }
+// if (t.getMissClassified() == R_children) {
+ if (numExtraMisClassIfPrun(t) == 0) {
+ // prune off the candidate node
+ tree_changed = true;
+ prune_off(t, 0);
+ }
+
+ }
+ if (tree_changed) {
+ TreeStats stats = new TreeStats();
+ stats.iteration_id(0);
+ update_tree_stats(stats, 0.0d, 0); // error_estimation = stats.getCostEstimation() for the set (cross_validation or test error)
+// tree_sequence_stats.add(stats);
+ }
+ }
+
private void iterate_trees(int i) {
if (slog.debug() !=null)
slog.debug().log(focus_tree.toString() +"\n");
@@ -199,11 +246,207 @@
// instead of getting the first node, have to process all nodes and prune all
// write a method to prune all
//TreeNode best_node = candidate_nodes.get(0);
+
+
+ int change_in_training_misclass = 0; // (k)
+ for (TreeNode candidate_node:candidate_nodes) {
+ // prune off the candidate node
+ prune_off(candidate_node, i);
+ change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
+
+ }
TreeStats stats = new TreeStats();
stats.iteration_id(i);
+ update_tree_stats(stats, min_alpha, change_in_training_misclass); // error_estimation = stats.getCostEstimation() for the set (cross_validation or test error)
+ if (slog.debug() !=null)
+ slog.debug().log(":sequence_trees:error "+ stats.getCostEstimation() +"<?"+ procedure.getErrorEstimate() * 1.6 +"\n");
+
+ if (stats.getCostEstimation() < MAX_ERROR_RATIO) { //procedure.getValidationErrorEstimate() * 1.6) {
+ // if the error of the tree is not that bad
+
+ if (stats.getCostEstimation() < best_tree_stats.getCostEstimation()) {
+ best_tree_stats = stats;
+ if (slog.debug() !=null)
+ slog.debug().log(":sequence_trees:best node updated \n");
+
+ }
+ // TODO update alpha_proc by increasing the min_alpha the_alpha += 10;
+ alpha_proc.init_proc(alpha_proc.getAlpha() + INIT_ALPHA);
+ iterate_trees(i+1);
+ } else {
+ //TODO update.setStopTree();
+ return;
+ }
+ } else {
+ if (slog.debug() !=null)
+ slog.debug().log(":sequence_trees:no candidate node is found ???? \n");
+ }
+
+ }
+
+
+ private void prune_off(TreeNode candidate_node, int i) {
+
+ LeafNode best_clone = new LeafNode(focus_tree.getTargetDomain(), candidate_node.getLabel());
+ best_clone.setRank( candidate_node.getRank());
+ best_clone.setNumMatch(candidate_node.getNumMatch()); //num of matching instances to the leaf node
+ best_clone.setNumClassification(candidate_node.getNumLabeled()); //num of (correctly) classified instances at the leaf node
+
+ NodeUpdate update = new NodeUpdate(candidate_node, best_clone);
+ //TODO
+ update.iteration_id(i);
+
+ update.setDecisionTree(focus_tree);
+ //change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
+ int num_leaves = candidate_node.getNumLeaves();
+
+ TreeNode father_node = candidate_node.getFather();
+ if (father_node != null) {
+ for(Object key: father_node.getChildrenKeys()) {
+ if (father_node.getChild(key).equals(candidate_node)) {
+ father_node.putNode(key, best_clone);
+ break;
+ }
+ }
+ updateLeaves(father_node, -num_leaves+1);
+ } else {
+ // this node does not have any father node it is the root node of the tree
+ focus_tree.setRoot(best_clone);
+ }
+
+ //updates.get(dt_0.getId()).add(update);
+ tree_sequence.add(update);
+
+ }
+
+ private void update_tree_stats(TreeStats stats, double computed_alpha, int change_in_training_error) {
+ ArrayList<InstanceList> sets = procedure.getSets(focus_tree.getId());
+ InstanceList learning_set = sets.get(0);
+ InstanceList validation_set = sets.get(1);
+
+ int num_error = 0;
+ SingleTreeTester t= new SingleTreeTester(focus_tree);
+
+ System.out.println(validation_set.getSize());
+ for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
+ Integer result = t.test(validation_set.getInstance(index_i));
+ if (result == Stats.INCORRECT) {
+ num_error ++;
+ }
+ }
+ double percent_error = Util.division(num_error, validation_set.getSize());
+// int _num_error = 0;
+// for (int index_i = 0; index_i < learning_set.getSize(); index_i++) {
+// Integer result = t.test(learning_set.getInstance(index_i));
+// if (result == Stats.INCORRECT) {
+// _num_error ++;
+// }
+// }
+// double _percent_error = Util.division(_num_error, learning_set.getSize());
+
+ int new_num_leaves = focus_tree.getRoot().getNumLeaves();
+
+ double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_error, procedure.getTrainingDataSize(focus_tree.getId()));
+ focus_tree.setTrainingError(new_resubstitution_cost);
+
+ double cost_complexity = new_resubstitution_cost + computed_alpha * (new_num_leaves);
+
+ if (slog.debug() !=null)
+ slog.debug().log(":sequence_trees:cost_complexity of selected tree "+ cost_complexity +"\n");
+
+
+ stats.setAlpha(computed_alpha);
+ stats.setCostEstimation(percent_error);
+ //stats.setTrainingCost(_percent_error);
+ stats.setResubstitution_cost(new_resubstitution_cost);
+ // Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
+ stats.setCost_complexity(cost_complexity);
+ stats.setNum_terminal_nodes(new_num_leaves);
+ tree_sequence_stats.add(stats);
+ }
+
+ // memory optimized
+ public void find_candidate_nodes(TreeNode my_node, ArrayList<TreeNode> nodes) {
+
+ if (my_node instanceof LeafNode) {
+
+ //leaves.add((LeafNode) my_node);
+ return;
+ } else {
+ // if you prune that one k more instances are misclassified
+
+ int k = numExtraMisClassIfPrun(my_node);
+ int num_leaves = my_node.getNumLeaves();
+
+ if (k==0) {
+ if (slog.debug() !=null)
+ slog.debug().log(":search_alphas:k == 0\n" );
+
+ }
+ double num_training_data = (double)procedure.getTrainingDataSize(focus_tree.getId());
+ double alpha = ( (double)k / num_training_data) /((double)(num_leaves-1));
+ if (slog.debug() !=null)
+ slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+alpha_proc.getAlpha()+ " k "+k+" num_leaves "+num_leaves+" all "+ procedure.getTrainingDataSize(focus_tree.getId()) + "\n");
+
+ //the_alpha = alpha_proc.check_node(alpha, the_alpha, my_node, nodes);
+ alpha_proc.check_node(alpha, my_node, nodes);
+
+ for (Object attributeValue : my_node.getChildrenKeys()) {
+ TreeNode child = my_node.getChild(attributeValue);
+ find_candidate_nodes(child, nodes);
+ //nodes.pop();
+ }
+ }
+ return;
+ }
+
+ public double getTheAlpha() {
+ return alpha_proc.getAlpha();
+ }
+
+
+
+
+ private void iterate_trees_(int i) {
+ if (slog.debug() !=null)
+ slog.debug().log(focus_tree.toString() +"\n");
+ // if number of non-terminal nodes in the tree is more than 1
+ // = if there exists at least one non-terminal node different than root
+ if (focus_tree.getNumNonTerminalNodes() < 1) {
+ if (slog.debug() !=null)
+ slog.debug().log(":sequence_trees:TERMINATE-There is no non-terminal nodes? " + focus_tree.getNumNonTerminalNodes() +"\n");
+ return;
+ } else if (focus_tree.getNumNonTerminalNodes() == 1 && focus_tree.getRoot().getNumLeaves()<=1) {
+ if (slog.debug() !=null)
+ slog.debug().log(":sequence_trees:TERMINATE-There is only one node left which is root node " + focus_tree.getNumNonTerminalNodes()+ " and it has only one leaf (pruned)" +focus_tree.getRoot().getNumLeaves()+"\n");
+ return;
+ }
+ // for each non-leaf subtree
+ ArrayList<TreeNode> candidate_nodes = new ArrayList<TreeNode>();
+ // to find all candidates with min_alpha value
+ //TreeSequenceProc search = new TreeSequenceProc(the_alpha, alpha_proc);//100000.0d, new MinAlphaProc());
+
+
+ find_candidate_nodes(focus_tree.getRoot(), candidate_nodes);
+ //double min_alpha = search.getTheAlpha();
+ double min_alpha = getTheAlpha();
+ System.out.println("!!!!!!!!!!! dt:"+focus_tree.getId()+" ite "+i+" alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
+
+ if (candidate_nodes.size() >0) {
+ // The one or more subtrees with that value of will be replaced by leaves
+ // instead of getting the first node, have to process all nodes and prune all
+ // write a method to prune all
+ //TreeNode best_node = candidate_nodes.get(0);
+ TreeStats stats = new TreeStats();
+ stats.iteration_id(i);
+
int change_in_training_misclass = 0; // (k)
for (TreeNode candidate_node:candidate_nodes) {
+ // prune off the candidate node
+ //prune_off(candidate_node, i);
+ //change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
+/* */
LeafNode best_clone = new LeafNode(focus_tree.getTargetDomain(), candidate_node.getLabel());
best_clone.setRank( candidate_node.getRank());
best_clone.setNumMatch(candidate_node.getNumMatch()); //num of matching instances to the leaf node
@@ -213,9 +456,8 @@
update.iteration_id(i);
update.setDecisionTree(focus_tree);
- change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
+
int num_leaves = candidate_node.getNumLeaves();
-
TreeNode father_node = candidate_node.getFather();
if (father_node != null) {
for(Object key: father_node.getChildrenKeys()) {
@@ -232,15 +474,17 @@
//updates.get(dt_0.getId()).add(update);
tree_sequence.add(update);
-
+/**/
}
-
- ArrayList<InstanceList> sets = procedure.getFold(focus_tree.getId());
- //InstanceList learning_set = sets.get(0);
+/**/
+ ArrayList<InstanceList> sets = procedure.getSets(focus_tree.getId());
+ InstanceList learning_set = sets.get(0);
InstanceList validation_set = sets.get(1);
int num_error = 0;
SingleTreeTester t= new SingleTreeTester(focus_tree);
+
+ System.out.println(validation_set.getSize());
for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
Integer result = t.test(validation_set.getInstance(index_i));
if (result == Stats.INCORRECT) {
@@ -248,10 +492,18 @@
}
}
double percent_error = Util.division(num_error, validation_set.getSize());
+// int _num_error = 0;
+// for (int index_i = 0; index_i < learning_set.getSize(); index_i++) {
+// Integer result = t.test(learning_set.getInstance(index_i));
+// if (result == Stats.INCORRECT) {
+// _num_error ++;
+// }
+// }
+// double _percent_error = Util.division(_num_error, learning_set.getSize());
int new_num_leaves = focus_tree.getRoot().getNumLeaves();
- double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_misclass, procedure.getTrainingDataSize(focus_tree.getId())/*focus_tree.FACTS_READ*/);
+ double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_misclass, procedure.getTrainingDataSize(focus_tree.getId()));
double cost_complexity = new_resubstitution_cost + min_alpha * (new_num_leaves);
@@ -260,20 +512,22 @@
stats.setAlpha(min_alpha);
- stats.setTest_cost(percent_error);
+ stats.setCostEstimation(percent_error);
+// stats.setTrainingCost(_percent_error);
stats.setResubstitution_cost(new_resubstitution_cost);
// Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
stats.setCost_complexity(cost_complexity);
stats.setNum_terminal_nodes(new_num_leaves);
+
+/**/
tree_sequence_stats.add(stats);
-
if (slog.debug() !=null)
slog.debug().log(":sequence_trees:error "+ percent_error +"<?"+ procedure.getErrorEstimate() * 1.6 +"\n");
if (percent_error < MAX_ERROR_RATIO) { //procedure.getValidationErrorEstimate() * 1.6) {
// if the error of the tree is not that bad
- if (percent_error < best_tree_stats.getTest_cost()) {
+ if (percent_error < best_tree_stats.getCostEstimation()) {
best_tree_stats = stats;
if (slog.debug() !=null)
slog.debug().log(":sequence_trees:best node updated \n");
@@ -281,7 +535,7 @@
}
// TODO update alpha_proc by increasing the min_alpha the_alpha += 10;
alpha_proc.init_proc(alpha_proc.getAlpha() + INIT_ALPHA);
- iterate_trees(i+1);
+ iterate_trees_(i+1);
} else {
//TODO update.setStopTree();
return;
@@ -292,55 +546,8 @@
}
}
-
- // memory optimized
- public void find_candidate_nodes(TreeNode my_node, ArrayList<TreeNode> nodes) {
-
- if (my_node instanceof LeafNode) {
-
- //leaves.add((LeafNode) my_node);
- return;
- } else {
- // if you prune that one k more instances are misclassified
-
- int k = numExtraMisClassIfPrun(my_node);
- int num_leaves = my_node.getNumLeaves();
-
- if (k==0) {
- if (slog.debug() !=null)
- slog.debug().log(":search_alphas:k == 0\n" );
-
- }
- double alpha = ((double)k)/((double)procedure.getTrainingDataSize(focus_tree.getId())/*focus_tree.FACTS_READ*/ * (num_leaves-1));
- if (slog.debug() !=null)
- slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+alpha_proc.getAlpha()+ " k "+k+" num_leaves "+num_leaves+" all "+ procedure.getTrainingDataSize(focus_tree.getId()) + "\n");
-
- //the_alpha = alpha_proc.check_node(alpha, the_alpha, my_node, nodes);
- alpha_proc.check_node(alpha, my_node, nodes);
-
- for (Object attributeValue : my_node.getChildrenKeys()) {
- TreeNode child = my_node.getChild(attributeValue);
- find_candidate_nodes(child, nodes);
- //nodes.pop();
- }
- }
- return;
- }
-
- public double getTheAlpha() {
- return alpha_proc.getAlpha();
- }
}
- public boolean isChildOf(TreeNode cur_node, TreeNode found_node) {
- TreeNode parent = cur_node.getFather();
- if (parent == null)
- return false;
- if (parent.equals(found_node))
- return true;
- else
- return isChildOf(parent, found_node);
- }
public interface AlphaSelectionProc {
public double check_node(double cur_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes);
@@ -350,15 +557,20 @@
public class AnAlphaProc implements AlphaSelectionProc{
private double an_alpha;
- public AnAlphaProc(double value) {
+ private double CART_EPSILON;
+ public AnAlphaProc(double value, double epsilon) {
an_alpha = value;
+ CART_EPSILON = epsilon;
}
+// public AnAlphaProc(double value) {
+// an_alpha = value;
+// }
public void init_proc(double value) {
// TODO ????
}
- public double check_node(double cur_alpha,TreeNode cur_node, ArrayList<TreeNode> nodes) {
- if (Util.epsilon(cur_alpha - an_alpha)) {
+ public double check_node(double cur_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes) {
+ if (Util.epsilon(cur_alpha - an_alpha, CART_EPSILON)) {
for(TreeNode parent:nodes) {
if (isChildOf(cur_node, parent))
return cur_alpha;// it is not added
@@ -377,10 +589,13 @@
public class MinAlphaProc implements AlphaSelectionProc{
private double sum_min_alpha, init_min;
private int num_minimum;
- public MinAlphaProc(double value) {
+
+ private double CART_EPSILON;
+ public MinAlphaProc(double value, double epsilon) {
init_min = value;
sum_min_alpha = 0;
num_minimum = 0;
+ CART_EPSILON = epsilon;
}
public void init_proc(double value) {
@@ -394,7 +609,7 @@
if (slog.debug() !=null)
slog.debug().log(":search_alphas:alpha "+ cur_alpha+ "/"+average_of_min_alphas+ " diff "+(cur_alpha - average_of_min_alphas)+"\n");
- if (Util.epsilon(cur_alpha - average_of_min_alphas)) {
+ if (Util.epsilon(cur_alpha - average_of_min_alphas, CART_EPSILON)) {
// check if the cur_node is a child of any node that has been added to the list before.
for(TreeNode parent:nodes) {
if (isChildOf(cur_node, parent))
@@ -428,7 +643,18 @@
return sum_min_alpha/num_minimum;
}
}
+ public boolean isChildOf(TreeNode cur_node, TreeNode found_node) {
+ TreeNode parent = cur_node.getFather();
+ if (parent == null)
+ return false;
+ if (parent.equals(found_node))
+ return true;
+ else
+ return isChildOf(parent, found_node);
+ }
+
+
public class NodeUpdate{
private boolean stopTree;
@@ -438,12 +664,6 @@
private TreeNode old_node, node_update;
private int iteration_id;
- //public TreeStats stats;
-// private int num_terminal_nodes;
-// private double cross_validated_cost;
-// private double resubstitution_cost;
-// private double cost_complexity;
-// private double alpha;
// to set an node update with the worst cross validated error
public NodeUpdate(double error) {
@@ -481,10 +701,12 @@
private double resubstitution_cost;
private double cost_complexity;
private double alpha;
+// private double training_error;
public TreeStats() {
iteration_id = 0;
}
+
// to set an node update with the worst cross validated error
public TreeStats(double error) {
iteration_id = 0;
@@ -502,11 +724,11 @@
this.num_terminal_nodes = num_terminal_nodes;
}
- public double getTest_cost() {
+ public double getCostEstimation() {
return test_cost;
}
- public void setTest_cost(double valid_cost) {
+ public void setCostEstimation(double valid_cost) {
this.test_cost = valid_cost;
}
@@ -534,6 +756,14 @@
this.alpha = alpha;
}
+// public void setTrainingCost(double _percent_error) {
+// training_error = _percent_error;
+// }
+//
+// public double getTrainingCost() {
+// return training_error;
+// }
+
}
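For reference, the quantities manipulated by the pruner above follow standard CART cost-complexity pruning. A hedged summary in LaTeX notation, matching the code in this diff (N is procedure.getTrainingDataSize(id), N_test is procedure.getTestDataSize(id), |\tilde{T}_t| is the number of leaves under node t, and it is assumed that getMissClassified() returns the misclassification count of a whole subtree):

    R(t) = numMatch(t) - numLabeled(t)                        % the R() helper: node misclassification count
    k(t) = R(t) - \sum_{c \in children(t)} R(T_c)             % numExtraMisClassIfPrun(t)
    \alpha(t) = \frac{k(t)/N}{|\tilde{T}_t| - 1}              % candidate alpha in find_candidate_nodes
    R_\alpha(T) = R_{resub}(T) + \alpha \cdot |\tilde{T}|     % cost_complexity in update_tree_stats
    SE = \sqrt{ E(1 - E) / N_{test} }                         % standart_error_estimate for the BEST tree's error E

The tree in the sequence with the smallest estimated cost is reported as BEST; the standard-error value is computed but, as far as this diff shows, not yet used for a 1-SE selection rule.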
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -12,7 +12,7 @@
import org.drools.learner.eval.Heuristic;
import org.drools.learner.eval.InformationContainer;
import org.drools.learner.eval.InstDistribution;
-import org.drools.learner.eval.StoppingCriterion;
+import org.drools.learner.eval.stopping.StoppingCriterion;
import org.drools.learner.tools.FeatureNotSupported;
import org.drools.learner.tools.Util;
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java 2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -11,10 +11,12 @@
import org.drools.learner.builder.DecisionTreeBuilder.TreeAlgo;
import org.drools.learner.eval.CrossValidation;
import org.drools.learner.eval.Entropy;
-import org.drools.learner.eval.EstimatedNodeSize;
+import org.drools.learner.eval.ErrorEstimate;
import org.drools.learner.eval.GainRatio;
import org.drools.learner.eval.Heuristic;
-import org.drools.learner.eval.StoppingCriterion;
+import org.drools.learner.eval.TestSample;
+import org.drools.learner.eval.stopping.EstimatedNodeSize;
+import org.drools.learner.eval.stopping.StoppingCriterion;
import org.drools.learner.tools.FeatureNotSupported;
import org.drools.learner.tools.Util;
@@ -152,60 +154,6 @@
return learner.getTree();
}
- public static DecisionTree createSinglePrunnedC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSinglePrunnedC45(wm, obj_class, new Entropy());
- }
- public static DecisionTree createSinglePrunnedC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSinglePrunnedC45(wm, obj_class, new GainRatio());
- }
-
- protected static DecisionTree createSinglePrunnedC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
- DataType data = Learner.DEFAULT_DATA;
- C45Learner learner = new C45Learner(h);
-
- SingleTreeBuilder single_builder = new SingleTreeBuilder();
-
-// String algo_suffices = org.drools.learner.deprecated.DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
-// String executionSignature = org.drools.learner.deprecated.DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
- String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
- String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-
-
-
- /* create the memory */
- Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
- single_builder.build(mem, learner);//obj_class, target_attr, working_attr
-
- CrossValidation validater = new CrossValidation(10, mem.getClassInstances());
- validater.validate(learner);
-
- DecisionTreePruner pruner = new DecisionTreePruner(validater);
- pruner.prun_to_estimate();
-
- // you should be able to get the pruned tree
- // prun.getMinimumCostTree()
- // prun.getOptimumCostTree()
-
- // test the tree
- SingleTreeTester tester = new SingleTreeTester(learner.getTree());
- tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature);
-
- /* Once Talpha is found the tree that is finally suggested for use is that
- * which minimises the cost-complexity using and all the data use the pruner to prun the tree
- */
- pruner.prun_tree(learner.getTree());
-
-
- // test the tree again
-
-
- //Tester.test(c45, mem.getClassInstances());
-
- learner.getTree().setSignature(executionSignature);
- return learner.getTree();
- }
-
-
public static DecisionTree createSingleC45E_StoppingCriteria(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
return createSingleC45_Stop(wm, obj_class, new Entropy());
}
@@ -238,22 +186,51 @@
return learner.getTree();
}
-
- public static DecisionTree createSinglePrunnedStopC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSinglePrunnedStopC45(wm, obj_class, new Entropy());
+ public static DecisionTree createSingleCrossPrunnedStopC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+
+ ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+ stopping_criteria.add(new EstimatedNodeSize(0.05));
+ return createSinglePrunnedC45(wm, obj_class, new Entropy(), new CrossValidation(10), stopping_criteria);
}
- public static DecisionTree createSinglePrunnedStopC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
- return createSinglePrunnedStopC45(wm, obj_class, new GainRatio());
+ public static DecisionTree createSingleCrossPrunnedStopC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+ stopping_criteria.add(new EstimatedNodeSize(0.05));
+ return createSinglePrunnedC45(wm, obj_class, new GainRatio(), new CrossValidation(10), stopping_criteria);
}
- protected static DecisionTree createSinglePrunnedStopC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
- DataType data = Learner.DEFAULT_DATA;
+ public static DecisionTree createSingleTestPrunnedStopC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+
ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
stopping_criteria.add(new EstimatedNodeSize(0.05));
+ return createSinglePrunnedC45(wm, obj_class, new Entropy(), new TestSample(0.2d), stopping_criteria);
+ }
+ public static DecisionTree createSingleTestPrunnedStopC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+ stopping_criteria.add(new EstimatedNodeSize(0.05));
+ return createSinglePrunnedC45(wm, obj_class, new GainRatio(), new TestSample(0.2), stopping_criteria);
+ }
+
+ public static DecisionTree createSingleCVPrunnedC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+
+ ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+ //stopping_criteria.add(new EstimatedNodeSize(0.05));
+ return createSinglePrunnedC45(wm, obj_class, new Entropy(), new CrossValidation(10), stopping_criteria);
+ }
+ public static DecisionTree createSingleCVPrunnedC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+ ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+ //stopping_criteria.add(new EstimatedNodeSize(0.05));
+ return createSinglePrunnedC45(wm, obj_class, new GainRatio(), new CrossValidation(10), stopping_criteria);
+ }
+
+ protected static DecisionTree createSinglePrunnedC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h, ErrorEstimate validater, ArrayList<StoppingCriterion> stopping_criteria) throws FeatureNotSupported {
+ DataType data = Learner.DEFAULT_DATA;
+
C45Learner learner = new C45Learner(h, stopping_criteria);
SingleTreeBuilder single_builder = new SingleTreeBuilder();
-
+
+// String algo_suffices = org.drools.learner.deprecated.DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
+// String executionSignature = org.drools.learner.deprecated.DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
@@ -261,9 +238,9 @@
Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
single_builder.build(mem, learner);//obj_class, target_attr, working_attr
- CrossValidation validater = new CrossValidation(10, mem.getClassInstances());
- validater.validate(learner);
+ validater.validate(learner, mem.getClassInstances());
+
DecisionTreePruner pruner = new DecisionTreePruner(validater);
pruner.prun_to_estimate();
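The consolidated createSinglePrunnedC45 above wires the whole pipeline together. A minimal sketch of the call order, assuming the org.drools.learner imports already listed in this file and a caller that supplies wm, objClass, a validater (CrossValidation or TestSample) and stoppingCriteria; this is an illustration of the flow, not a verbatim copy of the method:

    C45Learner learner = new C45Learner(new Entropy(), stoppingCriteria);
    SingleTreeBuilder builder = new SingleTreeBuilder();
    Memory mem = Memory.createFromWorkingMemory(wm, objClass, learner.getDomainAlgo(), Learner.DEFAULT_DATA);
    builder.build(mem, learner);                              // grow the full tree T_max
    validater.validate(learner, mem.getClassInstances());     // ErrorEstimate: new CrossValidation(10) or new TestSample(0.2)
    DecisionTreePruner pruner = new DecisionTreePruner(validater);
    pruner.prun_to_estimate();                                // build the tree sequence, pick the best alpha
    pruner.prun_tree(learner.getTree());                      // final prune of T_max, as in the removed createSinglePrunnedC45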
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java 2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -12,7 +12,7 @@
import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
-public class CrossValidation implements Estimator{
+public class CrossValidation implements ErrorEstimate{
private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(CrossValidation.class, SimpleLogger.DEFAULT_LEVEL);
private static SimpleLogger slog = LoggerFactory.getSysOutLogger(CrossValidation.class, SimpleLogger.DEBUG);
@@ -33,7 +33,9 @@
private int VALID_SET_0 = 0, VALID_SET_1 = 1;
private boolean WITH_REP = false;
- public CrossValidation(int _k, InstanceList _instances) {
+
+
+ public CrossValidation(int _k) {
if (_k <=1) {
if (flog.warn() !=null)
flog.warn().log("There is 1 or less number of folds specified, i am setting "+MIN_NUM_FOLDS+" folds\n");
@@ -46,13 +48,10 @@
training_error_estimate = 0.0d;
num_leaves_estimate = 0.0d;
alpha_estimate = 0.0d;
- class_instances = _instances;
- num_instances = class_instances.getSize();
fold_indices = new int [k_fold] [2];
}
public int [] cross_set(int N) {
-
if (WITH_REP)
return Util.bag_w_rep(N, N);
else
@@ -61,7 +60,9 @@
// for small samples
- public void validate(Learner _trainer) {
+ public void validate(Learner _trainer, InstanceList _instances) {
+ class_instances = _instances;
+ num_instances = class_instances.getSize();
if (class_instances.getTargets().size()>1 ) {
//throw new FeatureNotSupported("There is more than 1 target candidates");
if (flog.error() !=null)
@@ -69,7 +70,7 @@
System.exit(0);
// TODO put the feature not supported exception || implement it
}
-
+
crossed_set = cross_set(num_instances);
// N-fold
@@ -78,10 +79,10 @@
int i = 0;
// int[] bag;
while (i++ < k_fold ) {
- int fold_size = getFoldSize(i);
+ int fold_size = getTestDataSize(i);
if (slog.debug() !=null)
slog.debug().log("i "+(i-1)+"/"+k_fold);
- ArrayList<InstanceList> sets = getFold(i-1);
+ ArrayList<InstanceList> sets = getSets(i-1);
InstanceList learning_set = sets.get(0);
InstanceList validation_set = sets.get(1);
@@ -132,7 +133,7 @@
while (fold_index < k_fold) {
fold_indices[fold_index][VALID_SET_0] = divide_index;
- int fold_size = getFoldSize(fold_index);
+ int fold_size = getTestDataSize(fold_index);
fold_indices[fold_index][VALID_SET_1] = fold_indices[fold_index][VALID_SET_0] + fold_size -1;
System.out.println(fold_indices[fold_index][VALID_SET_0] +" - "+ fold_indices[fold_index][VALID_SET_1]);
@@ -141,11 +142,11 @@
}
}
- public ArrayList<InstanceList> getFold(int i) {
+ public ArrayList<InstanceList> getSets(int i) {
// // first part divide = 0; divide < fold_size*i
// // the validation set divide = fold_size*i; divide < fold_size*(i+1)-1
// // last part divide = fold_size*(i+1); divide < N
- int valid_fold_size = getFoldSize(i);
+ int valid_fold_size = getTestDataSize(i);
InstanceList learning_set = new InstanceList(class_instances.getSchema(), num_instances - valid_fold_size +1);
InstanceList validation_set = new InstanceList(class_instances.getSchema(), valid_fold_size);
for (int divide_index = 0; divide_index < num_instances; divide_index++){
@@ -169,36 +170,15 @@
}
-// public ArrayList<InstanceList> getFolds() {
-//// // first part divide = 0; divide < fold_size*i
-//// // the validation set divide = fold_size*i; divide < fold_size*(i+1)-1
-//// // last part divide = fold_size*(i+1); divide < N
-// int fold_size = getFoldSize();
-// InstanceList learning_set = new InstanceList(class_instances.getSchema(), num_instances - fold_size +1);
-// InstanceList validation_set = new InstanceList(class_instances.getSchema(), fold_size);
-// for (int divide_index = 0; divide_index < num_instances; divide_index++){
-// if (divide_index >= fold_size*i && divide_index < fold_size*(i+1)-1) { // validation
-// validation_set.addAsInstance(class_instances.getInstance(crossed_set[divide_index]));
-// } else { // learninf part
-// learning_set.addAsInstance(class_instances.getInstance(crossed_set[divide_index]));
-// }
-// }
-//
-// ArrayList<InstanceList> lists = new ArrayList<InstanceList>(2);
-// lists.add(learning_set);
-// lists.add(validation_set);
-// return lists;
-//
-// }
public int getTrainingDataSize(int i) {
- return num_instances-getFoldSize(i);
+ return num_instances-getTestDataSize(i);
}
public double getAlphaEstimate() {
return alpha_estimate;
}
- private int getFoldSize(int i) {
+ public int getTestDataSize(int i) {
int excess = num_instances % k_fold;
return (int) num_instances/k_fold + (i < excess? 1:0);
}
@@ -206,8 +186,8 @@
return validation_error_estimate;
}
- public ArrayList<DecisionTree> getEstimators() {
- return forest;
+ public DecisionTree getEstimator(int i) {
+ return forest.get(i);
}
public int getEstimatorSize() {
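The net effect of the CrossValidation changes is that the instance list is now bound at validate() time rather than in the constructor, so the estimator can sit behind the new ErrorEstimate interface and be swapped for a TestSample. A small before/after sketch using the names from the factory diff:

    // before this revision: data fixed at construction
    CrossValidation validater = new CrossValidation(10, mem.getClassInstances());
    validater.validate(learner);

    // after: only the fold count is fixed up front; the same call site also accepts a TestSample
    ErrorEstimate estimator = new CrossValidation(10);
    estimator.validate(learner, mem.getClassInstances());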
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ErrorEstimate.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ErrorEstimate.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ErrorEstimate.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,21 @@
+package org.drools.learner.eval;
+
+import java.util.ArrayList;
+
+import org.drools.learner.DecisionTree;
+import org.drools.learner.InstanceList;
+import org.drools.learner.builder.Learner;
+
+public interface ErrorEstimate {
+
+ public void validate(Learner _trainer, InstanceList _instances);
+
+ public int getEstimatorSize();
+ public DecisionTree getEstimator(int i);
+ public ArrayList<InstanceList> getSets(int id);
+ public int getTestDataSize(int i);
+ public int getTrainingDataSize(int i);
+
+ public double getErrorEstimate();
+ public double getAlphaEstimate();
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ErrorEstimate.java
___________________________________________________________________
Name: svn:eol-style
+ native
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,152 @@
+package org.drools.learner.eval;
+
+import java.util.ArrayList;
+
+import org.drools.learner.DecisionTree;
+import org.drools.learner.InstanceList;
+import org.drools.learner.Stats;
+import org.drools.learner.builder.Learner;
+import org.drools.learner.builder.SingleTreeTester;
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
+import org.drools.learner.tools.Util;
+
+public class TestSample implements ErrorEstimate{
+
+ private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(TestSample.class, SimpleLogger.DEFAULT_LEVEL);
+ private static SimpleLogger slog = LoggerFactory.getSysOutLogger(TestSample.class, SimpleLogger.DEBUG);
+
+ private int TEST_SET_0 = 0, TEST_SET_1 = 0;
+
+ private InstanceList class_instances;
+ private double error_estimate, training_error_estimate, num_leaves_estimate, alpha_estimate;
+ private boolean WITH_REP = false;
+ private int num_instances;
+ private int[] crossed_set;
+ private double test_ratio;
+
+ private DecisionTree dt;
+
+ public TestSample(double _ratio) {
+ test_ratio = _ratio;
+ error_estimate = 0.0d;
+ training_error_estimate = 0.0d;
+ num_leaves_estimate = 0.0d;
+ alpha_estimate = 0.0d;
+
+ }
+
+ public int [] cross_set(int N) {
+ if (WITH_REP )
+ return Util.bag_w_rep(N, N);
+ else
+ return Util.bag_wo_rep(N, N);
+ }
+ public void validate(Learner _trainer, InstanceList _instances) {
+ class_instances = _instances;
+ num_instances = class_instances.getSize();
+ if (class_instances.getTargets().size()>1 ) {
+ //throw new FeatureNotSupported("There is more than 1 target candidates");
+ if (flog.error() !=null)
+ flog.error().log("There is more than 1 target candidates\n");
+ System.exit(0);
+ // TODO put the feature not supported exception || implement it
+ }
+
+ TEST_SET_1 = (int)(test_ratio * num_instances)+1;
+ System.out.println(test_ratio +"*"+ num_instances+" "+TEST_SET_1);
+
+ crossed_set = cross_set(num_instances);
+
+ ArrayList<InstanceList> sets = getSets(0);
+ InstanceList learning_set = sets.get(0);
+ InstanceList test_set = sets.get(1);
+
+ dt = _trainer.train_tree(learning_set);
+ dt.setID(0);
+
+
+ int error = 0;
+ SingleTreeTester t= new SingleTreeTester(dt);
+ if (slog.debug() !=null)
+ slog.debug().log("validation fold_size " +test_set.getSize() + "\n");
+ for (int index_i = 0; index_i < test_set.getSize(); index_i++) {
+
+ if (slog.warn() !=null)
+ slog.warn().log(" validation index_i " +index_i + (index_i ==test_set.getSize() -1?"\n":""));
+ Integer result = t.test(test_set.getInstance(index_i));
+ if (result == Stats.INCORRECT) {
+ error ++;
+ }
+ }
+ dt.setValidationError(Util.division(error, test_set.getSize()));
+ dt.calc_num_node_leaves(dt.getRoot());
+
+ if (slog.error() !=null)
+ slog.error().log("The estimate of : "+(0)+" training=" +dt.getTrainingError() +" valid=" + dt.getValidationError() +" num_leaves=" + dt.getRoot().getNumLeaves()+"\n");
+
+ /* moving averages */
+ error_estimate = dt.getValidationError();
+ training_error_estimate = (double)dt.getTrainingError();
+ num_leaves_estimate = (double)dt.getRoot().getNumLeaves();
+
+
+// if (slog.stat() !=null)
+// slog.stat().stat("."+ (i == k_fold?"\n":""));
+ alpha_estimate = (error_estimate - training_error_estimate) /num_leaves_estimate;
+ if (slog.stat() !=null)
+ slog.stat().log(" The estimates: training=" +training_error_estimate +" valid=" + error_estimate +" num_leaves=" + num_leaves_estimate+ " the alpha"+ alpha_estimate+"\n");
+ // TODO how to compute a best tree from the forest
+ }
+
+ public ArrayList<InstanceList> getSets(int i) {
+// // first part divide = 0; divide < fold_size*i
+// // the validation set divide = fold_size*i; divide < fold_size*(i+1)-1
+// // last part divide = fold_size*(i+1); divide < N
+ InstanceList learning_set = new InstanceList(class_instances.getSchema(), num_instances - TEST_SET_1 +1);
+ InstanceList validation_set = new InstanceList(class_instances.getSchema(), TEST_SET_1);
+ for (int divide_index = 0; divide_index < num_instances; divide_index++){
+
+ if (slog.info() !=null)
+ slog.info().log("index " +divide_index+ " fold_size" + TEST_SET_1 + " = from "+TEST_SET_0+" to "+TEST_SET_1+" num_instances "+num_instances+ "\n");
+ if (divide_index >= TEST_SET_0 && divide_index <= TEST_SET_1) { // validation
+ // validation [fold_size*i, fold_size*(i+1))
+ if (slog.info() !=null)
+ slog.info().log("validation one " +divide_index+ "\n");
+ validation_set.addAsInstance(class_instances.getInstance(crossed_set[divide_index]));
+ } else { // learninf part
+ learning_set.addAsInstance(class_instances.getInstance(crossed_set[divide_index]));
+ }
+ }
+
+ ArrayList<InstanceList> lists = new ArrayList<InstanceList>(2);
+ lists.add(learning_set);
+ lists.add(validation_set);
+ return lists;
+
+ }
+
+ public double getAlphaEstimate() {
+ return alpha_estimate;
+ }
+
+ public double getErrorEstimate() {
+ return error_estimate;
+ }
+
+ public int getEstimatorSize() {
+ return 1;
+ }
+
+ public DecisionTree getEstimator(int i) {
+ return dt;
+ }
+
+ public int getTrainingDataSize(int i) {
+ return num_instances - TEST_SET_1;
+ }
+ public int getTestDataSize(int i) {
+ return TEST_SET_1;
+ }
+
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java
___________________________________________________________________
Name: svn:eol-style
+ native
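A quick worked example of the split above (illustrative numbers only): with 100 instances and test_ratio = 0.2, TEST_SET_1 = (int)(0.2 * 100) + 1 = 21, so getTestDataSize() reports 21 held-out instances and getTrainingDataSize() reports 79. The single tree is trained on the learning partition, its error on the held-out partition becomes error_estimate, and alpha_estimate = (error_estimate - training_error_estimate) / num_leaves_estimate.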
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/EstimatedNodeSize.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/EstimatedNodeSize.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/EstimatedNodeSize.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,31 @@
+package org.drools.learner.eval.stopping;
+
+import org.drools.learner.eval.InformationContainer;
+
+public class EstimatedNodeSize implements StoppingCriterion {
+
+ private double outlier_percentage;
+ private int estimated_sum_branch;
+ private int num_times_branched;
+
+
+ public EstimatedNodeSize(double o_p) {
+ outlier_percentage = o_p;
+ num_times_branched = 0;
+ }
+
+
+ public boolean stop(InformationContainer best_attr_eval) {
+ int d = best_attr_eval.getDepth();
+ estimated_sum_branch += best_attr_eval.domain.getCategoryCount();
+ num_times_branched ++;
+ double estimated_branch = (double)estimated_sum_branch/(double)num_times_branched;
+ // N/(b^d)
+ double estimated_size = best_attr_eval.getTotalNumData()/Math.pow(estimated_branch, d);
+ System.out.println("EstimatedNodeSize:stop: " +best_attr_eval.getNumData() + " <= " + ( Math.ceil(estimated_size*outlier_percentage)-1) +" / "+estimated_size);
+ if (best_attr_eval.getNumData() <= Math.ceil(estimated_size*outlier_percentage)-1)
+ return true;
+ else
+ return false;
+ }
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/EstimatedNodeSize.java
___________________________________________________________________
Name: svn:eol-style
+ native
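To make the EstimatedNodeSize rule concrete (illustrative numbers, not taken from the code): with 10000 training instances, an average of 2 branches per split observed so far and a candidate split at depth 5, the estimated node size is 10000 / 2^5 = 312.5; with outlier_percentage = 0.05 the threshold is ceil(312.5 * 0.05) - 1 = 15, so growth stops at any node holding 15 or fewer instances.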
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/ImpurityDecrease.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/ImpurityDecrease.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/ImpurityDecrease.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,20 @@
+package org.drools.learner.eval.stopping;
+
+import org.drools.learner.eval.InformationContainer;
+
+
+public class ImpurityDecrease implements StoppingCriterion {
+
+ private double beta = 0.1;
+
+ public ImpurityDecrease(double _beta) {
+ beta = _beta;
+ }
+ public boolean stop(InformationContainer best_attr_eval) {
+ if (best_attr_eval.attribute_eval < beta)
+ return true;
+ else
+ return false;
+ }
+
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/ImpurityDecrease.java
___________________________________________________________________
Name: svn:eol-style
+ native
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,19 @@
+package org.drools.learner.eval.stopping;
+
+import org.drools.learner.eval.InformationContainer;
+
+
+public class MaximumDepth implements StoppingCriterion {
+
+ private int limit_depth;
+ public MaximumDepth(int _depth) {
+ limit_depth = _depth;
+ }
+ public boolean stop(InformationContainer best_attr_eval) {
+ if (best_attr_eval.getDepth() <= limit_depth)
+ return false;
+ else
+ return true;
+ }
+
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java
___________________________________________________________________
Name: svn:eol-style
+ native
Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/StoppingCriterion.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/StoppingCriterion.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/StoppingCriterion.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,9 @@
+package org.drools.learner.eval.stopping;
+
+import org.drools.learner.eval.InformationContainer;
+
+public interface StoppingCriterion {
+
+ public boolean stop(InformationContainer best_attr_eval);
+
+}
Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/StoppingCriterion.java
___________________________________________________________________
Name: svn:eol-style
+ native
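The three implementations above all reach the learner through this single-method interface. A minimal registration sketch, assuming the C45Learner(Heuristic, ArrayList<StoppingCriterion>) constructor used in the factory diff; the EstimatedNodeSize value 0.05 comes from that diff, while the MaximumDepth and ImpurityDecrease values are hypothetical:

    ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>();
    criteria.add(new EstimatedNodeSize(0.05)); // stop on nodes much smaller than the expected size at their depth
    criteria.add(new MaximumDepth(10));        // hypothetical depth cap
    criteria.add(new ImpurityDecrease(0.1));   // hypothetical minimum impurity decrease per split
    C45Learner learner = new C45Learner(new Entropy(), criteria);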
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java 2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -66,8 +66,8 @@
return (double)x/(double)y;
}
- public static boolean epsilon(double d) {
- return Math.abs(d) <= 0.0001;
+ public static boolean epsilon(double d, double epsilon) {
+ return Math.abs(d) <= epsilon; //0.000001;
}
/* TODO make this all_fields arraylist as hashmap */
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -72,14 +72,17 @@
// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
// break;
case 400:
- decision_tree = DecisionTreeFactory.createSinglePrunnedC45E(session, obj_class);
+ decision_tree = DecisionTreeFactory.createSingleCVPrunnedC45E(session, obj_class);
break;
case 500:
decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
break;
case 600:
- decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+ decision_tree = DecisionTreeFactory.createSingleCrossPrunnedStopC45E(session, obj_class);
break;
+ case 601:
+ decision_tree = DecisionTreeFactory.createSingleTestPrunnedStopC45E(session, obj_class);
+ break;
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java 2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -75,13 +75,13 @@
// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
// break;
case 400:
- decision_tree = DecisionTreeFactory.createSinglePrunnedC45E(session, obj_class);
+ decision_tree = DecisionTreeFactory.createSingleCVPrunnedC45E(session, obj_class);
break;
case 500:
decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
break;
case 600:
- decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+ decision_tree = DecisionTreeFactory.createSingleCrossPrunnedStopC45E(session, obj_class);
break;
default:
decision_tree = DecisionTreeFactory.createSingleID3E(session, obj_class);
Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java 2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java 2008-08-11 01:52:30 UTC (rev 21422)
@@ -71,7 +71,7 @@
decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
break;
case 600:
- decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+ decision_tree = DecisionTreeFactory.createSingleCrossPrunnedStopC45E(session, obj_class);
break;
// case 3:
// decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);