[jboss-svn-commits] JBL Code SVN: r20992 - in labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner: builder and 1 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Thu Jul 10 07:21:00 EDT 2008
Author: gizil
Date: 2008-07-10 07:21:00 -0400 (Thu, 10 Jul 2008)
New Revision: 20992
Added:
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
Modified:
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
Log:
Decision Tree pruning procedure
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-07-10 10:21:27 UTC (rev 20991)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-07-10 11:21:00 UTC (rev 20992)
@@ -30,6 +30,8 @@
private String execution_signature;
public long FACTS_READ = 0;
+ private int validation_error;
+
public DecisionTree(Schema inst_schema, String _target) {
this.obj_schema = inst_schema; //inst_schema.getObjectClass();
@@ -110,6 +112,13 @@
return this.getRoot().voteFor(i);
}
+
+ /**
+  * Stores the validation error of this tree.
+  *
+  * @param error the error count to remember for this tree
+  */
+ public void setValidationError(int error) {
+     this.validation_error = error;
+ }
+ /**
+  * Returns the validation error last stored with
+  * {@code setValidationError(int)}.
+  *
+  * @return the stored validation error
+  */
+ public int getValidationError() {
+     return this.validation_error;
+ }
public void setSignature(String executionSignature) {
execution_signature = executionSignature;
}
@@ -122,6 +131,6 @@
public String toString() {
String out = "Facts scanned " + FACTS_READ + "\n";
return out + root.toString();
- }
+ }
}
Added: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java 2008-07-10 11:21:00 UTC (rev 20992)
@@ -0,0 +1,137 @@
+package org.drools.learner;
+
+import java.util.ArrayList;
+
+import org.drools.learner.eval.Estimator;
+
+
+
+/**
+ * Prunes the decision trees held by an {@link Estimator}.
+ *
+ * <p>For every tree the pruner repeatedly looks for the internal node with the
+ * smallest alpha value, replaces it with a leaf that carries the node's label
+ * and counters, and records the replacement as a {@link NodeUpdate}. Pruning of
+ * a tree stops as soon as a pass finds no candidate.
+ */
+public class DecisionTreePruner {
+
+    /**
+     * Value the minimum-alpha search starts from on every pass; a node is only
+     * a pruning candidate when its alpha does not exceed this value.
+     */
+    private static final double INITIAL_MIN_ALPHA = 10.0;
+
+    private Estimator procedure;
+
+    /** Trees seen so far; update_owners.get(i) owns the list updates.get(i). */
+    private ArrayList<DecisionTree> update_owners;
+    /** Per tree, the replacements applied to it in the order they happened. */
+    private ArrayList<ArrayList<NodeUpdate>> updates;
+    private int num_trees_to_grow;
+
+    /** Smallest alpha found during the current search pass. */
+    private double min_alpha = INITIAL_MIN_ALPHA;
+
+    /**
+     * @param proc the procedure whose trees are going to be pruned
+     */
+    public DecisionTreePruner(Estimator proc) {
+        procedure = proc;
+        num_trees_to_grow = procedure.getEstimatorSize();
+        updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
+        update_owners = new ArrayList<DecisionTree>(num_trees_to_grow);
+    }
+
+    /** Prunes every tree returned by the estimator. */
+    public void global_prun() {
+        for (DecisionTree dt : procedure.getEstimators()) {
+            prun(dt);
+        }
+    }
+
+    /**
+     * Prunes a single tree in place, one node per pass, until no candidate is
+     * left.
+     *
+     * @param dt_0 the tree to prune
+     */
+    public void prun(DecisionTree dt_0) {
+        // Looked up by tree identity rather than by dt_0.getId(), so the tree
+        // ids do not have to be valid list indexes and prun() can be called
+        // without global_prun().
+        ArrayList<NodeUpdate> tree_updates = updatesFor(dt_0);
+
+        // Iterative form of the former tail recursion: one replacement per pass.
+        while (true) {
+            calc_numleaves(dt_0.getRoot());
+
+            // Every pass starts from scratch; a minimum left over from an
+            // earlier pass (or from another tree) would hide all candidates.
+            min_alpha = INITIAL_MIN_ALPHA;
+            ArrayList<TreeNode> candidate_nodes = new ArrayList<TreeNode>();
+            search_alphas(dt_0, dt_0.getRoot(), candidate_nodes);
+            System.out.println("alpha: "+min_alpha);
+
+            if (candidate_nodes.isEmpty()) {
+                return;
+            }
+
+            TreeNode best_node = candidate_nodes.get(0);
+            LeafNode best_clone = new LeafNode(dt_0.getTargetDomain(), best_node.getLabel());
+            best_clone.setRank(best_node.getRank());
+            best_clone.setNumMatch(best_node.getNumMatch()); //num of matching instances to the leaf node
+            best_clone.setNumClassification(best_node.getNumLabeled()); //num of (correctly) classified instances at the leaf node
+
+            if (!replaceChild(best_node.getFather(), best_node, best_clone)) {
+                // The father does not reference the node, so the tree did not
+                // change; stop instead of finding the same node forever.
+                return;
+            }
+            tree_updates.add(new NodeUpdate(best_node, best_clone));
+        }
+    }
+
+    /** Returns the update list of the given tree, creating it on first use. */
+    private ArrayList<NodeUpdate> updatesFor(DecisionTree dt) {
+        int index = update_owners.indexOf(dt);
+        if (index < 0) {
+            update_owners.add(dt);
+            updates.add(new ArrayList<NodeUpdate>());
+            index = updates.size() - 1;
+        }
+        return updates.get(index);
+    }
+
+    /** Swaps old_child for new_child below father_node; false if old_child is not a child. */
+    private boolean replaceChild(TreeNode father_node, TreeNode old_child, LeafNode new_child) {
+        for (Object key : father_node.getChildrenKeys()) {
+            if (father_node.getChild(key).equals(old_child)) {
+                new_child.setFather(father_node);
+                father_node.putNode(key, new_child);
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Walks the subtree below my_node and collects into nodes all internal
+     * nodes sharing the smallest alpha; min_alpha is updated along the way.
+     * The root (a node without father) is never collected because it cannot
+     * be replaced through its father.
+     */
+    private void search_alphas(DecisionTree dt, TreeNode my_node, ArrayList<TreeNode> nodes) {
+        if (my_node instanceof LeafNode) {
+            return;
+        }
+
+        // if you prune that one k more instances are misclassified
+        // NOTE(review): formula kept as committed; confirm the sign of k.
+        int num_misclassified = (int) my_node.getNumMatch() - my_node.getNumLabeled();
+        int k = dt.getValidationError() - num_misclassified;
+
+        // Computed in floating point: the former integer division truncated
+        // alpha and threw ArithmeticException for a zero denominator.
+        double denominator = (double) dt.FACTS_READ * (my_node.getNumLeaves() - 1);
+
+        if (my_node.getFather() != null && denominator > 0.0d) {
+            double alpha = k / denominator;
+            if (alpha == min_alpha) {
+                // tie: add this one to the set
+                nodes.add(my_node);
+            } else if (alpha < min_alpha) {
+                // new minimum: drop the ones found so far. The list is cleared
+                // in place so that the caller sees the result.
+                min_alpha = alpha;
+                nodes.clear();
+                nodes.add(my_node);
+            }
+        }
+
+        for (Object attributeValue : my_node.getChildrenKeys()) {
+            search_alphas(dt, my_node.getChild(attributeValue), nodes);
+        }
+    }
+
+    /**
+     * Counts the leaves below my_node and stores the count on every internal
+     * node of the subtree.
+     *
+     * @return the number of leaves; 1 if my_node is itself a leaf
+     */
+    private int calc_numleaves(TreeNode my_node) {
+        if (my_node instanceof LeafNode) {
+            return 1;
+        }
+        int leaves = 0;
+        for (Object child_key : my_node.getChildrenKeys()) {
+            leaves += calc_numleaves(my_node.getChild(child_key));
+        }
+        my_node.setNumLeaves(leaves);
+        return leaves;
+    }
+
+    /** Record of one pruning step: the internal node removed and the leaf put in its place. */
+    public class NodeUpdate {
+
+        private final TreeNode old_node;
+        private final LeafNode new_node;
+
+        public NodeUpdate(TreeNode old_n, LeafNode new_n) {
+            old_node = old_n;
+            new_node = new_n;
+        }
+
+        /** @return the internal node that was pruned away */
+        public TreeNode getOldNode() {
+            return old_node;
+        }
+
+        /** @return the leaf that replaced the pruned node */
+        public LeafNode getNewNode() {
+            return new_node;
+        }
+
+    }
+
+}
Property changes on: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
___________________________________________________________________
Name: svn:eol-style
+ native
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java 2008-07-10 10:21:27 UTC (rev 20991)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java 2008-07-10 11:21:00 UTC (rev 20992)
@@ -91,6 +91,7 @@
}
return;
}
+
public int getNumPaths() {
return paths.size();
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java 2008-07-10 10:21:27 UTC (rev 20991)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java 2008-07-10 11:21:00 UTC (rev 20992)
@@ -29,6 +29,9 @@
return this.num_intances_classified;
}
+ /**
+  * A leaf always counts as exactly one leaf.
+  *
+  * @return 1
+  */
+ public int getNumLeaves() {
+ return 1;
+ }
// public Integer evaluate(Instance i) {
// String targetFName = super.getDomain().getFName();
//
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java 2008-07-10 10:21:27 UTC (rev 20991)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java 2008-07-10 11:21:00 UTC (rev 20992)
@@ -2,6 +2,7 @@
import java.util.Collection;
import java.util.Hashtable;
+import java.util.List;
import org.drools.learner.tools.Util;
@@ -11,6 +12,7 @@
//private static final Logger flog = LoggerFactory.getFileLogger(TreeNode.class);
private Domain domain;
+ private TreeNode father;
private Hashtable<Object, TreeNode> children;
/* TODO explain
* rank:
@@ -21,8 +23,12 @@
// Number of all instances matching at that node
private double num_matching_instances;
+ private Object label;
+ private int label_size;
+ private int leaves;
public TreeNode(Domain domain) {
+ this.father = null;
this.domain = domain;
this.children = new Hashtable<Object, TreeNode>();
@@ -67,7 +73,36 @@
public void setInfoMea(double mea) {
this.infoMea = mea;
}
+ /**
+  * @return the label stored on this node with {@code setLabel(Object)},
+  *         or {@code null} if none was stored
+  */
+ public Object getLabel() {
+     return this.label;
+ }
+
+ /**
+  * Stores the label of this node.
+  *
+  * @param get_winner_class the class value to remember as this node's label
+  */
+ public void setLabel(Object get_winner_class) {
+     this.label = get_winner_class;
+ }
+ /**
+  * Stores how many instances support this node's label.
+  *
+  * @param supportersFor the number of supporting instances
+  */
+ public void setNumLabeled(int supportersFor) {
+     this.label_size = supportersFor;
+ }
+
+ /**
+  * @return the count last stored with {@code setNumLabeled(int)}
+  */
+ public int getNumLabeled() {
+     return this.label_size;
+ }
+
+ /**
+  * @return the leaf count last stored with {@code setNumLeaves(int)};
+  *         the value is not computed here
+  */
+ public int getNumLeaves() {
+     return this.leaves;
+ }
+
+ /**
+  * Stores the number of leaves below this node.
+  *
+  * @param leaves2 the leaf count to remember
+  */
+ public void setNumLeaves(int leaves2) {
+     this.leaves = leaves2;
+ }
+
+ /**
+  * Stores the parent of this node.
+  *
+  * @param currentNode the node this node hangs below
+  */
+ public void setFather(TreeNode currentNode) {
+     this.father = currentNode;
+ }
+ /**
+  * @return the parent stored with {@code setFather(TreeNode)}, or
+  *         {@code null} when no parent has been set
+  */
+ public TreeNode getFather() {
+     return this.father;
+ }
public Object voteFor(Instance i) {
final Object attr_value = i.getAttrValue(this.domain.getFReferenceName());
final Object category = domain.getCategoryOf(attr_value);
@@ -117,6 +152,7 @@
}
return buf.toString();
}
-
+
+
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-07-10 10:21:27 UTC (rev 20991)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-07-10 11:21:00 UTC (rev 20992)
@@ -45,6 +45,7 @@
(double)this.getDataSize()/* total size of data fed to dt*/);
classifiedNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
classifiedNode.setNumClassification(data_stats.getSum()); //num of classified instances at the leaf node
+
//classifiedNode.setInfoMea(mea)
return classifiedNode;
}
@@ -61,6 +62,8 @@
noAttributeLeftNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
noAttributeLeftNode.setNumClassification(data_stats.getVoteFor(winner)); //num of classified instances at the leaf node
//noAttributeLeftNode.setInfoMea(best_attr_eval.attribute_eval);
+
+
/* we need to know how many guys cannot be classified and who these guys are */
data_stats.missClassifiedInstances(missclassified_data);
@@ -82,8 +85,10 @@
currentNode.setRank((double)data_stats.getSum()/
(double)this.getDataSize() /* total size of data fed to dt*/);
currentNode.setInfoMea(best_attr_eval.attribute_eval);
-
-
+ //what the highest represented class is and what proportion of items at that node actually are that class
+ currentNode.setLabel(data_stats.get_winner_class());
+ currentNode.setNumLabeled(data_stats.getSupportersFor(data_stats.get_winner_class()).size());
+
Hashtable<Object, InstDistribution> filtered_stats = null;
try {
filtered_stats = data_stats.split(best_attr_eval);
@@ -110,9 +115,12 @@
majorityNode.setNumMatch(0);
majorityNode.setNumClassification(0);
//currentNode.setInfoMea(best_attr_eval.attribute_eval);
+
+ majorityNode.setFather(currentNode);
currentNode.putNode(category, majorityNode);
} else {
TreeNode newNode = train(child_dt, filtered_stats.get(category));//, attributeNames_copy
+ newNode.setFather(currentNode);
currentNode.putNode(category, newNode);
}
}
Added: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java 2008-07-10 11:21:00 UTC (rev 20992)
@@ -0,0 +1,101 @@
+package org.drools.learner.eval;
+
+
+import java.util.ArrayList;
+
+import org.drools.learner.DecisionTree;
+import org.drools.learner.InstanceList;
+import org.drools.learner.Stats;
+import org.drools.learner.builder.Learner;
+import org.drools.learner.builder.SingleTreeTester;
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
+import org.drools.learner.tools.Util;
+
+public class CrossValidation implements Estimator{
+
+ private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(CrossValidation.class, SimpleLogger.DEFAULT_LEVEL);
+ private static SimpleLogger slog = LoggerFactory.getSysOutLogger(CrossValidation.class, SimpleLogger.DEFAULT_LEVEL);
+
+ private int k_fold;
+ private double error_estimate;
+ private ArrayList<DecisionTree> forest;
+
+ private boolean WITH_REP = false;
+ public CrossValidation(int _k) {
+ k_fold = _k;
+ forest = new ArrayList<DecisionTree> (k_fold);
+ error_estimate = 0.0d;
+ }
+
+ // for small samples
+ public void validate(InstanceList class_instances, Learner _trainer) {
+ if (class_instances.getTargets().size()>1 ) {
+ //throw new FeatureNotSupported("There is more than 1 target candidates");
+ if (flog.error() !=null)
+ flog.error().log("There is more than 1 target candidates\n");
+ System.exit(0);
+ // TODO put the feature not supported exception || implement it
+ }
+
+ int N = class_instances.getSize();
+ int[] bag;
+ if (WITH_REP)
+ bag = Util.bag_w_rep(N, N);
+ else
+ bag = Util.bag_wo_rep(N, N);
+
+
+ int fold_size = (int)N/k_fold;
+ int i = 0;
+// int[] bag;
+ while (i++ < k_fold ) {
+// // first part divide = 0; divide < fold_size*i
+// // the validation set divide = fold_size*i; divide < fold_size*(i+1)-1
+// // last part divide = fold_size*(i+1); divide < N
+
+ InstanceList learning_set = new InstanceList(class_instances.getSchema(), N- fold_size);
+ InstanceList validation_set = new InstanceList(class_instances.getSchema(), fold_size);
+ for (int divide_index = 0; divide_index <N; divide_index++){
+ if (divide_index >= fold_size*i && divide_index < fold_size*(i+1)-1) { // validation
+ validation_set.addAsInstance(class_instances.getInstance(bag[divide_index]));
+ } else { // learninf part
+ learning_set.addAsInstance(class_instances.getInstance(bag[divide_index]));
+ }
+ }
+
+ DecisionTree dt = _trainer.train_tree(learning_set);
+ dt.setID(i);
+ forest.add(dt);
+
+ int error = 0;
+ SingleTreeTester t= new SingleTreeTester(dt);
+ for (int index_i = 0; index_i < fold_size; index_i++) {
+ Integer result = t.test(validation_set.getInstance(index_i));
+ if (result == Stats.INCORRECT) {
+ error ++;
+ }
+ }
+ dt.setValidationError(error);
+ error_estimate += error/k_fold;
+
+
+ if (slog.stat() !=null)
+ slog.stat().stat(".");
+
+ }
+ // TODO how to compute a best tree from the forest
+ }
+
+ public double getErrorEstimate() {
+ return error_estimate;
+ }
+
+ public ArrayList<DecisionTree> getEstimators() {
+ return forest;
+ }
+
+ public int getEstimatorSize() {
+ return k_fold;
+ }
+}
Property changes on: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
___________________________________________________________________
Name: svn:eol-style
+ native
Added: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java 2008-07-10 11:21:00 UTC (rev 20992)
@@ -0,0 +1,11 @@
+package org.drools.learner.eval;
+
+import java.util.ArrayList;
+
+import org.drools.learner.DecisionTree;
+
+/**
+ * A procedure that holds a collection of decision trees.
+ */
+public interface Estimator {
+
+    /**
+     * @return the number of trees this procedure is set up for
+     */
+    int getEstimatorSize();
+
+    /**
+     * @return the trees held by this procedure
+     */
+    ArrayList<DecisionTree> getEstimators();
+}
Property changes on: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
___________________________________________________________________
Name: svn:eol-style
+ native
More information about the jboss-svn-commits
mailing list