[jboss-svn-commits] JBL Code SVN: r21006 - in labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner: builder and 1 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Fri Jul 11 11:06:30 EDT 2008
Author: gizil
Date: 2008-07-11 11:06:30 -0400 (Fri, 11 Jul 2008)
New Revision: 21006
Modified:
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
Log:
updates in the pruner
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-07-11 13:15:28 UTC (rev 21005)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-07-11 15:06:30 UTC (rev 21006)
@@ -30,7 +30,7 @@
private String execution_signature;
public long FACTS_READ = 0;
- private int validation_error;
+ private int validation_error, training_error;
public DecisionTree(Schema inst_schema, String _target) {
this.obj_schema = inst_schema; //inst_schema.getObjectClass();
@@ -119,6 +119,32 @@
public int getValidationError() {
return validation_error;
}
+
+ public void setTrainingError(int error) {
+ // TODO Auto-generated method stub
+ training_error = error;
+ }
+ public int getTrainingError() {
+ // TODO Auto-generated method stub
+ return training_error;
+ }
+
+ public int calc_numleaves(TreeNode my_node) {
+ if (my_node instanceof LeafNode) {
+
+ return 1;
+ }
+ int leaves = 0;
+ for (Object child_key: my_node.getChildrenKeys()) {
+ /* split the last two class at the same time */
+
+ TreeNode child = my_node.getChild(child_key);
+ leaves += calc_numleaves(child);
+
+ }
+ my_node.setNumLeaves(leaves);
+ return leaves;
+ }
public void setSignature(String executionSignature) {
execution_signature = executionSignature;
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-07-11 13:15:28 UTC (rev 21005)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-07-11 15:06:30 UTC (rev 21006)
@@ -2,6 +2,7 @@
import java.util.ArrayList;
+import org.drools.learner.builder.SingleTreeTester;
import org.drools.learner.eval.Estimator;
@@ -12,32 +13,34 @@
private ArrayList<ArrayList<NodeUpdate>> updates;
private int num_trees_to_grow;
+
+ private NodeUpdate best_update;
public DecisionTreePruner(Estimator proc) {
procedure = proc;
num_trees_to_grow = procedure.getEstimatorSize();
updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
+
+ best_update = new NodeUpdate(10000000.0d);
}
- public void global_prun() {
-
+ public void prun_to_estimate() {
for (DecisionTree dt: procedure.getEstimators()) {
// dt.getId()
updates.add(new ArrayList<NodeUpdate>());
+ //calc_numleaves(dt.getRoot()); // this is done in the estimator
prun(dt);
}
}
-
-
- public void prun(DecisionTree dt_0) {
+ private void prun(DecisionTree dt_0) {
- calc_numleaves(dt_0.getRoot());
-
// for each non-leaf subtree
ArrayList<TreeNode> candidate_nodes = new ArrayList<TreeNode>();
- search_alphas(dt_0, dt_0.getRoot(), candidate_nodes);
+ TreeCandidate search = new TreeCandidate();
+ search.search_alphas(dt_0, dt_0.getRoot(), candidate_nodes);
+ double min_alpha = search.getMinAlpha();
System.out.println("alpha: "+min_alpha);
if (candidate_nodes.size() >0) {
@@ -49,9 +52,12 @@
NodeUpdate update = new NodeUpdate(best_node, best_clone);
//update.set
+ update.setAlpha(min_alpha);
+ update.setDecisionTree(dt_0);
+ int k = numExtraMisClassIfPrun(best_node); // extra misclassified guys
+ int num_leaves = best_node.getNumLeaves();
+ int new_num_leaves = dt_0.getRoot().getNumLeaves() - num_leaves +1;
- updates.get(dt_0.getId()).add(update);
-
TreeNode father_node = best_node.getFather();
for(Object key: father_node.getChildrenKeys()) {
if (father_node.getChild(key).equals(best_node)) {
@@ -59,79 +65,189 @@
break;
}
}
- prun(dt_0);
+ updateLeaves(father_node, -num_leaves+1);
+
+ ArrayList<InstanceList> sets = procedure.getFold(dt_0.getId());
+ //InstanceList learning_set = sets.get(0);
+ InstanceList validation_set = sets.get(1);
+
+ int error = 0;
+ SingleTreeTester t= new SingleTreeTester(dt_0);
+ for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
+ Integer result = t.test(validation_set.getInstance(index_i));
+ if (result == Stats.INCORRECT) {
+ error ++;
+ }
+ }
+
+
+ update.setCross_validated_cost(error);
+ int new_resubstitution_cost = dt_0.getTrainingError() + k;
+ double cost_complexity = new_resubstitution_cost + min_alpha * (new_num_leaves);
+
+ update.setResubstitution_cost(new_resubstitution_cost);
+ // Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
+ update.setCost_complexity(cost_complexity);
+ update.setNum_terminal_nodes(new_num_leaves);
+
+ updates.get(dt_0.getId()).add(update);
+
+
+ if (error < procedure.getValidationErrorEstimate() * 1.3) {
+ // if the error of the tree is not that bad
+
+ if (error < best_update.getCross_validated_cost()) {
+ best_update = update;
+ }
+
+ prun(dt_0);
+ } else {
+ update.setStopTree();
+ }
}
}
- private double min_alpha = 10.0;
- // memory optimized
- private void search_alphas(DecisionTree dt, TreeNode my_node, ArrayList<TreeNode> nodes) {
+ private void updateLeaves(TreeNode my_node, int i) {
+ my_node.setNumLeaves(my_node.getNumLeaves() + i);
+ TreeNode father_node = my_node.getFather();
+ if (father_node !=null)
+ updateLeaves(father_node, i);
- if (my_node instanceof LeafNode) {
+ }
+
+
+ private int numExtraMisClassIfPrun(TreeNode my_node) {
+ int num_misclassified = (int) my_node.getNumMatch() - my_node.getNumLabeled(); // needs to be cast because of
+
+ int kids_misclassified = 0;
+ for(Object key: my_node.getChildrenKeys()) {
+ TreeNode child = my_node.getChild(key);
+ kids_misclassified += (int) child.getNumMatch() - child.getNumLabeled();
+ }
+
+ return num_misclassified - kids_misclassified;
+ }
+
+ public class TreeCandidate {
+ private double min_alpha;
+ public TreeCandidate() {
+ min_alpha = 100000000.0d;
+ }
+
+ // memory optimized
+ private void search_alphas(DecisionTree dt, TreeNode my_node, ArrayList<TreeNode> nodes) {
- //leaves.add((LeafNode) my_node);
- return;
- } else {
- // if you prune that one k more instances are misclassified
- int num_misclassified = (int) my_node.getNumMatch() - my_node.getNumLabeled(); // needs to be cast because of
- int k = dt.getValidationError() - num_misclassified;
- int num_leaves = my_node.getNumLeaves();
-
- double alpha = k/(dt.FACTS_READ * (num_leaves-1));
- if (alpha == min_alpha) {
- // add this one to the set
- nodes.add(my_node);
- } else if (alpha < min_alpha) {
- min_alpha = alpha;
+ if (my_node instanceof LeafNode) {
- nodes = new ArrayList<TreeNode>();
- // remove the ones you found and replace with that one
- //tree_sequence.get(dt_id).put(my_node), alpha
+ //leaves.add((LeafNode) my_node);
+ return;
+ } else {
+ // if you prune that one k more instances are misclassified
- nodes.add(my_node);
+ int k = numExtraMisClassIfPrun(my_node);
+ int num_leaves = my_node.getNumLeaves();
- } else {
+ double alpha = k/(dt.FACTS_READ * (num_leaves-1));
+ if (alpha == min_alpha) {
+ // add this one to the set
+ nodes.add(my_node);
+ } else if (alpha < min_alpha) {
+ min_alpha = alpha;
+
+ nodes = new ArrayList<TreeNode>();
+ // remove the ones you found and replace with that one
+ //tree_sequence.get(dt_id).put(my_node), alpha
+
+ nodes.add(my_node);
+
+ } else {
+
+ }
+ for (Object attributeValue : my_node.getChildrenKeys()) {
+ TreeNode child = my_node.getChild(attributeValue);
+ search_alphas(dt, child, nodes);
+ //nodes.pop();
+ }
+
}
- for (Object attributeValue : my_node.getChildrenKeys()) {
- TreeNode child = my_node.getChild(attributeValue);
- search_alphas(dt, child, nodes);
- //nodes.pop();
- }
-
+
+ return;
}
-
- return;
- }
-
- private int calc_numleaves(TreeNode my_node) {
- if (my_node instanceof LeafNode) {
-
- return 1;
+ public double getMinAlpha() {
+ return min_alpha;
}
- int leaves = 0;
- for (Object child_key: my_node.getChildrenKeys()) {
- /* split the last two class at the same time */
-
- TreeNode child = my_node.getChild(child_key);
- int leaf = calc_numleaves(child);
-
- leaves += leaf;
-
- }
- my_node.setNumLeaves(leaves);
- return leaves;
}
public class NodeUpdate{
+ private boolean stopTree;
+
+ private DecisionTree tree;
+ private int num_terminal_nodes;
+ private double cross_validated_cost;
+ private double resubstitution_cost;
+ private double cost_complexity;
+ private double alpha;
+
+ public NodeUpdate(double error) {
+ cross_validated_cost = error;
+ }
+ public void setDecisionTree(DecisionTree dt_0) {
+ tree = dt_0;
+ }
public NodeUpdate(TreeNode old_n, LeafNode new_n) {
-
+ stopTree = false;
}
+
+ public void setStopTree() {
+ stopTree = true;
+ }
+
+ public int getNum_terminal_nodes() {
+ return num_terminal_nodes;
+ }
+
+ public void setNum_terminal_nodes(int num_terminal_nodes) {
+ this.num_terminal_nodes = num_terminal_nodes;
+ }
+
+ public double getCross_validated_cost() {
+ return cross_validated_cost;
+ }
+
+ public void setCross_validated_cost(double cross_validated_cost) {
+ this.cross_validated_cost = cross_validated_cost;
+ }
+
+ public double getResubstitution_cost() {
+ return resubstitution_cost;
+ }
+
+ public void setResubstitution_cost(double resubstitution_cost) {
+ this.resubstitution_cost = resubstitution_cost;
+ }
+
+ public double getCost_complexity() {
+ return cost_complexity;
+ }
+
+ public void setCost_complexity(double cost_complexity) {
+ this.cost_complexity = cost_complexity;
+ }
+
+ public double getAlpha() {
+ return alpha;
+ }
+
+ public void setAlpha(double alpha) {
+ this.alpha = alpha;
+ }
+
}
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java 2008-07-11 13:15:28 UTC (rev 21005)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java 2008-07-11 15:06:30 UTC (rev 21006)
@@ -153,6 +153,4 @@
return buf.toString();
}
-
-
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-07-11 13:15:28 UTC (rev 21005)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-07-11 15:06:30 UTC (rev 21006)
@@ -66,7 +66,7 @@
/* we need to know how many guys cannot be classified and who these guys are */
data_stats.missClassifiedInstances(missclassified_data);
-
+ dt.setTrainingError((int) (dt.getTrainingError() + data_stats.getSum()));
return noAttributeLeftNode;
}
@@ -115,6 +115,7 @@
majorityNode.setNumMatch(0);
majorityNode.setNumClassification(0);
//currentNode.setInfoMea(best_attr_eval.attribute_eval);
+ //dt.setTrainingError((int) (dt.getTrainingError() + data_stats.getSum()));
majorityNode.setFather(currentNode);
currentNode.putNode(category, majorityNode);
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java 2008-07-11 13:15:28 UTC (rev 21005)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java 2008-07-11 15:06:30 UTC (rev 21006)
@@ -2,8 +2,10 @@
import org.drools.WorkingMemory;
import org.drools.learner.DecisionTree;
+import org.drools.learner.DecisionTreePruner;
import org.drools.learner.Memory;
import org.drools.learner.builder.Learner.DataType;
+import org.drools.learner.eval.CrossValidation;
import org.drools.learner.eval.Entropy;
import org.drools.learner.eval.GainRatio;
import org.drools.learner.eval.Heuristic;
@@ -133,4 +135,33 @@
learner.getTree().setSignature(executionSignature);
return learner.getTree();
}
+
+ protected static DecisionTree createSinglePrunnedC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
+ DataType data = Learner.DEFAULT_DATA;
+ C45Learner learner = new C45Learner(h);
+
+ SingleTreeBuilder single_builder = new SingleTreeBuilder();
+
+ String algo_suffices = org.drools.learner.deprecated.DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
+ String executionSignature = org.drools.learner.deprecated.DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
+
+ /* create the memory */
+ Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
+ single_builder.build(mem, learner);//obj_class, target_attr, working_attr
+
+ CrossValidation validater = new CrossValidation(10, mem.getClassInstances());
+ validater.validate(learner);
+
+ DecisionTreePruner pruner = new DecisionTreePruner(validater);
+ pruner.prun_to_estimate();
+
+
+
+ SingleTreeTester tester = new SingleTreeTester(learner.getTree());
+ tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature);
+ //Tester.test(c45, mem.getClassInstances());
+
+ learner.getTree().setSignature(executionSignature);
+ return learner.getTree();
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java 2008-07-11 13:15:28 UTC (rev 21005)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java 2008-07-11 15:06:30 UTC (rev 21006)
@@ -17,19 +17,39 @@
private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(CrossValidation.class, SimpleLogger.DEFAULT_LEVEL);
private static SimpleLogger slog = LoggerFactory.getSysOutLogger(CrossValidation.class, SimpleLogger.DEFAULT_LEVEL);
- private int k_fold;
- private double error_estimate;
+ private int k_fold, num_instances;
+
+ /* estimated values
+ * validation_error_estimate is used to estimate the true misclassification rate
+ * */
+ private double validation_error_estimate, training_error_estimate, num_leaves_estimate, alpha_estimate;
+ private int [] crossed_set;
private ArrayList<DecisionTree> forest;
+ private InstanceList class_instances;
private boolean WITH_REP = false;
- public CrossValidation(int _k) {
+ public CrossValidation(int _k, InstanceList _instances) {
k_fold = _k;
forest = new ArrayList<DecisionTree> (k_fold);
- error_estimate = 0.0d;
+ validation_error_estimate = 0.0d;
+ training_error_estimate = 0.0d;
+ num_leaves_estimate = 0.0d;
+ alpha_estimate = 0.0d;
+ class_instances = _instances;
+
+ num_instances = class_instances.getSize();
}
+ public int [] cross_set(int N) {
+
+ if (WITH_REP)
+ return Util.bag_w_rep(N, N);
+ else
+ return Util.bag_wo_rep(N, N);
+ }
+
// for small samples
- public void validate(InstanceList class_instances, Learner _trainer) {
+ public void validate(Learner _trainer) {
if (class_instances.getTargets().size()>1 ) {
//throw new FeatureNotSupported("There is more than 1 target candidates");
if (flog.error() !=null)
@@ -37,32 +57,16 @@
System.exit(0);
// TODO put the feature not supported exception || implement it
}
-
- int N = class_instances.getSize();
- int[] bag;
- if (WITH_REP)
- bag = Util.bag_w_rep(N, N);
- else
- bag = Util.bag_wo_rep(N, N);
+ crossed_set = cross_set(num_instances);
- int fold_size = (int)N/k_fold;
+ int fold_size = getFoldSize();
int i = 0;
// int[] bag;
while (i++ < k_fold ) {
-// // first part divide = 0; divide < fold_size*i
-// // the validation set divide = fold_size*i; divide < fold_size*(i+1)-1
-// // last part divide = fold_size*(i+1); divide < N
-
- InstanceList learning_set = new InstanceList(class_instances.getSchema(), N- fold_size);
- InstanceList validation_set = new InstanceList(class_instances.getSchema(), fold_size);
- for (int divide_index = 0; divide_index <N; divide_index++){
- if (divide_index >= fold_size*i && divide_index < fold_size*(i+1)-1) { // validation
- validation_set.addAsInstance(class_instances.getInstance(bag[divide_index]));
- } else { // learninf part
- learning_set.addAsInstance(class_instances.getInstance(bag[divide_index]));
- }
- }
+ ArrayList<InstanceList> sets = getFold(i);
+ InstanceList learning_set = sets.get(0);
+ InstanceList validation_set = sets.get(1);
DecisionTree dt = _trainer.train_tree(learning_set);
dt.setID(i);
@@ -77,20 +81,52 @@
}
}
dt.setValidationError(error);
- error_estimate += error/k_fold;
+ dt.calc_numleaves(dt.getRoot());
+ validation_error_estimate += error/k_fold;
+ training_error_estimate += dt.getTrainingError()/k_fold;
+ num_leaves_estimate += dt.getRoot().getNumLeaves()/k_fold;
if (slog.stat() !=null)
slog.stat().stat(".");
}
+ alpha_estimate = (validation_error_estimate - training_error_estimate) /num_leaves_estimate;
+
// TODO how to compute a best tree from the forest
}
- public double getErrorEstimate() {
- return error_estimate;
+ public ArrayList<InstanceList> getFold(int i) {
+// // first part divide = 0; divide < fold_size*i
+// // the validation set divide = fold_size*i; divide < fold_size*(i+1)-1
+// // last part divide = fold_size*(i+1); divide < N
+ int fold_size = getFoldSize();
+ InstanceList learning_set = new InstanceList(class_instances.getSchema(), num_instances - fold_size +1);
+ InstanceList validation_set = new InstanceList(class_instances.getSchema(), fold_size);
+ for (int divide_index = 0; divide_index < num_instances; divide_index++){
+ if (divide_index >= fold_size*i && divide_index < fold_size*(i+1)-1) { // validation
+ validation_set.addAsInstance(class_instances.getInstance(crossed_set[divide_index]));
+ } else { // learninf part
+ learning_set.addAsInstance(class_instances.getInstance(crossed_set[divide_index]));
+ }
+ }
+
+ ArrayList<InstanceList> lists = new ArrayList<InstanceList>(2);
+ lists.add(learning_set);
+ lists.add(validation_set);
+ return lists;
+
}
+
+ private int getFoldSize() {
+ // TODO Auto-generated method stub
+ return (int) num_instances/k_fold;
+ }
+ public double getValidationErrorEstimate() {
+ return validation_error_estimate;
+ }
+
public ArrayList<DecisionTree> getEstimators() {
return forest;
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java 2008-07-11 13:15:28 UTC (rev 21005)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java 2008-07-11 15:06:30 UTC (rev 21006)
@@ -3,9 +3,13 @@
import java.util.ArrayList;
import org.drools.learner.DecisionTree;
+import org.drools.learner.InstanceList;
public interface Estimator {
public int getEstimatorSize();
public ArrayList<DecisionTree> getEstimators();
+ public ArrayList<InstanceList> getFold(int id);
+
+ public double getValidationErrorEstimate();
}
More information about the jboss-svn-commits mailing list