[jboss-svn-commits] JBL Code SVN: r20992 - in labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner: builder and 1 other directories.

jboss-svn-commits at lists.jboss.org jboss-svn-commits at lists.jboss.org
Thu Jul 10 07:21:00 EDT 2008


Author: gizil
Date: 2008-07-10 07:21:00 -0400 (Thu, 10 Jul 2008)
New Revision: 20992

Added:
   labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
   labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
   labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
Modified:
   labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java
   labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java
   labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java
   labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java
   labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
Log:
Decision Tree pruning procedure

Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java	2008-07-10 10:21:27 UTC (rev 20991)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java	2008-07-10 11:21:00 UTC (rev 20992)
@@ -30,6 +30,8 @@
 	private String execution_signature;
 	public long FACTS_READ = 0;
 
+	private int validation_error;
+
 	public DecisionTree(Schema inst_schema, String _target) {
 		this.obj_schema = inst_schema; //inst_schema.getObjectClass();
 
@@ -110,6 +112,13 @@
 		return this.getRoot().voteFor(i);
 	}
 	
+
+	public void setValidationError(int error) {
+		validation_error = error;
+	}	
+	public int getValidationError() {	// the value last handed to setValidationError
+		return this.validation_error;
+	}
 	public void setSignature(String executionSignature) {
 		execution_signature = executionSignature;
 	}
@@ -122,6 +131,6 @@
 	public String toString() {
 		String out = "Facts scanned " + FACTS_READ + "\n";
 		return out + root.toString();
-	}	
+	}
 	
 }

Added: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java	2008-07-10 11:21:00 UTC (rev 20992)
@@ -0,0 +1,137 @@
+package org.drools.learner;
+
+import java.util.ArrayList;
+
+import org.drools.learner.eval.Estimator;
+
+
+
+public class DecisionTreePruner {
+
+	/** Upper bound for alpha; min_alpha is reset to it before every search pass. */
+	private static final double MAX_ALPHA = 10.0;
+
+	private Estimator procedure;
+	/** One list of applied replacements per call to prun, in call order. */
+	private ArrayList<ArrayList<NodeUpdate>> updates;
+	private int num_trees_to_grow;
+	/** Smallest alpha found so far by the current search_alphas pass. */
+	private double min_alpha = MAX_ALPHA;
+
+	/** @param proc estimator that supplies the trees to prune */
+	public DecisionTreePruner(Estimator proc) {
+		procedure = proc;
+		num_trees_to_grow = procedure.getEstimatorSize();
+		updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
+	}
+
+	/** Prunes every tree returned by the estimator. */
+	public void global_prun() {
+		for (DecisionTree dt : procedure.getEstimators()) {
+			prun(dt);
+		}
+	}
+
+	/**
+	 * Repeatedly replaces the internal node with the smallest alpha by a leaf
+	 * carrying that node's label and counts. Stops when no candidate is left,
+	 * when the candidate is the root (no father to attach the leaf to), or when
+	 * the candidate is not found among its father's children. Replacements go to
+	 * a list of their own, so the tree id need not be a valid 0-based index.
+	 *
+	 * @param dt_0 the tree to prune in place
+	 */
+	public void prun(DecisionTree dt_0) {
+		ArrayList<NodeUpdate> tree_updates = new ArrayList<NodeUpdate>();
+		updates.add(tree_updates);
+		boolean replaced = true;
+		while (replaced) {
+			replaced = false;
+			calc_numleaves(dt_0.getRoot());
+			// for each non-leaf subtree; every pass starts from a fresh minimum
+			min_alpha = MAX_ALPHA;
+			ArrayList<TreeNode> candidate_nodes = new ArrayList<TreeNode>();
+			search_alphas(dt_0, dt_0.getRoot(), candidate_nodes);
+			System.out.println("alpha: "+min_alpha);
+
+			if (candidate_nodes.isEmpty()) {
+				return;
+			}
+			TreeNode best_node = candidate_nodes.get(0);
+			TreeNode father_node = best_node.getFather();
+			if (father_node == null) {
+				// best_node is the root: there is no parent slot to put a leaf into
+				return;
+			}
+			LeafNode best_clone = new LeafNode(dt_0.getTargetDomain(), best_node.getLabel());
+			best_clone.setRank(best_node.getRank());
+			best_clone.setNumMatch(best_node.getNumMatch());				//num of matching instances to the leaf node
+			best_clone.setNumClassification(best_node.getNumLabeled());	//num of (correctly) classified instances at the leaf node
+			best_clone.setFather(father_node);
+
+			for (Object key : father_node.getChildrenKeys()) {
+				if (father_node.getChild(key).equals(best_node)) {
+					father_node.putNode(key, best_clone);
+					tree_updates.add(new NodeUpdate(best_node, best_clone));
+					replaced = true;
+					break;
+				}
+			}
+		}
+	}
+
+	/**
+	 * Collects into nodes every internal node under my_node whose alpha equals
+	 * the smallest alpha found so far (min_alpha); nodes is modified in place.
+	 */
+	private void search_alphas(DecisionTree dt, TreeNode my_node, ArrayList<TreeNode> nodes) {
+		if (my_node instanceof LeafNode) {
+			return;
+		}
+		// if you prune that one k more instances are misclassified
+		// NOTE(review): k uses the whole tree's validation error - confirm the formula
+		int num_misclassified = (int) my_node.getNumMatch() - my_node.getNumLabeled();
+		int k = dt.getValidationError() - num_misclassified;
+		long denominator = dt.FACTS_READ * (my_node.getNumLeaves() - 1);
+
+		// no facts read or a single-leaf subtree: alpha is undefined, skip the node
+		if (denominator != 0) {
+			double alpha = (double) k / denominator;	// real, not integer, division
+			if (alpha < min_alpha) {
+				min_alpha = alpha;
+				// clear in place so that the caller's list holds the new best set
+				nodes.clear();
+				nodes.add(my_node);
+			} else if (alpha == min_alpha) {
+				nodes.add(my_node);
+			}
+		}
+
+		for (Object attributeValue : my_node.getChildrenKeys()) {
+			search_alphas(dt, my_node.getChild(attributeValue), nodes);
+		}
+	}
+
+	/** Stores the leaf count on every internal node under my_node; returns my_node's count. */
+	private int calc_numleaves(TreeNode my_node) {
+		if (my_node instanceof LeafNode) {
+			return 1;
+		}
+		int leaves = 0;
+		for (Object child_key : my_node.getChildrenKeys()) {
+			leaves += calc_numleaves(my_node.getChild(child_key));
+		}
+		my_node.setNumLeaves(leaves);
+		return leaves;
+	}
+
+	/** One pruning step: the node that was removed and the leaf that replaced it. */
+	public class NodeUpdate {
+		public final TreeNode old_node;
+		public final LeafNode new_node;
+		public NodeUpdate(TreeNode old_n, LeafNode new_n) {
+			old_node = old_n;
+			new_node = new_n;
+		}
+	}
+}


Property changes on: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
___________________________________________________________________
Name: svn:eol-style
   + native

Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java	2008-07-10 10:21:27 UTC (rev 20991)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java	2008-07-10 11:21:00 UTC (rev 20992)
@@ -91,6 +91,7 @@
 		}
 		return;
 	}
+
 	
 	public int getNumPaths() {
 		return paths.size();

Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java	2008-07-10 10:21:27 UTC (rev 20991)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java	2008-07-10 11:21:00 UTC (rev 20992)
@@ -29,6 +29,9 @@
 		return this.num_intances_classified;
 	}
 	
+	public int getNumLeaves() {	// a leaf node always reports exactly one leaf
+		return 1;
+	}
 //	public Integer evaluate(Instance i) {
 //		String targetFName = super.getDomain().getFName();
 //		

Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java	2008-07-10 10:21:27 UTC (rev 20991)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java	2008-07-10 11:21:00 UTC (rev 20992)
@@ -2,6 +2,7 @@
 
 import java.util.Collection;
 import java.util.Hashtable;
+import java.util.List;
 
 import org.drools.learner.tools.Util;
 
@@ -11,6 +12,7 @@
 	//private static final Logger flog = LoggerFactory.getFileLogger(TreeNode.class);	
 
 	private Domain domain;
+	private TreeNode father;
 	private Hashtable<Object, TreeNode> children;
 	/* TODO explain
 	 * rank:
@@ -21,8 +23,12 @@
 	
 	// Number of all instances matching at that node
 	private double num_matching_instances;
+	private Object label;
+	private int label_size;
+	private int leaves;
 	
 	public TreeNode(Domain domain) {
+		this.father = null;
 		this.domain = domain;
 		this.children = new Hashtable<Object, TreeNode>();
 		
@@ -67,7 +73,36 @@
 	public void setInfoMea(double mea) {
 		this.infoMea = mea;
 	}
+	public Object getLabel() {	// the value last passed to setLabel
+		return this.label;
+	}
+
+	public void setLabel(Object get_winner_class) {	// C45Learner passes the winner class of the node's data
+		this.label = get_winner_class;
+	}
 	
+	public void setNumLabeled(int supportersFor) {	// C45Learner passes the number of supporters of the winner class
+		this.label_size = supportersFor;
+	}
+	
+	public int getNumLabeled() {	// the value last passed to setNumLabeled
+		return this.label_size;
+	}
+	
+	public int getNumLeaves() {	// the value last passed to setNumLeaves
+		return this.leaves;
+	}
+	
+	public void setNumLeaves(int leaves2) {	// DecisionTreePruner.calc_numleaves stores the subtree's leaf count here
+		this.leaves = leaves2;
+	}
+	
+	public void setFather(TreeNode currentNode) {	// C45Learner passes the node this one is attached under
+		this.father = currentNode;
+	}
+	public TreeNode getFather() {	// null until setFather is called (the constructor sets it to null)
+		return this.father;
+	}
 	public Object voteFor(Instance i) {
 		final Object attr_value = i.getAttrValue(this.domain.getFReferenceName());
 		final Object category = domain.getCategoryOf(attr_value);
@@ -117,6 +152,7 @@
 		}
 		return buf.toString();
 	}
-	
 
+
+
 }

Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java	2008-07-10 10:21:27 UTC (rev 20991)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java	2008-07-10 11:21:00 UTC (rev 20992)
@@ -45,6 +45,7 @@
 									(double)this.getDataSize()/* total size of data fed to dt*/);
 			classifiedNode.setNumMatch(data_stats.getSum());						//num of matching instances to the leaf node
 			classifiedNode.setNumClassification(data_stats.getSum());				//num of classified instances at the leaf node
+			
 			//classifiedNode.setInfoMea(mea)
 			return classifiedNode;
 		}
@@ -61,6 +62,8 @@
 			noAttributeLeftNode.setNumMatch(data_stats.getSum());						//num of matching instances to the leaf node
 			noAttributeLeftNode.setNumClassification(data_stats.getVoteFor(winner));	//num of classified instances at the leaf node
 			//noAttributeLeftNode.setInfoMea(best_attr_eval.attribute_eval);
+			
+			
 			/* we need to know how many guys cannot be classified and who these guys are */
 			data_stats.missClassifiedInstances(missclassified_data);
 			
@@ -82,8 +85,10 @@
 		currentNode.setRank((double)data_stats.getSum()/
 							(double)this.getDataSize()									/* total size of data fed to dt*/);
 		currentNode.setInfoMea(best_attr_eval.attribute_eval);
-		
-		
+		//what the highest represented class is and what proportion of items at that node actually are that class
+		currentNode.setLabel(data_stats.get_winner_class());
+		currentNode.setNumLabeled(data_stats.getSupportersFor(data_stats.get_winner_class()).size());
+
 		Hashtable<Object, InstDistribution> filtered_stats = null;
 		try {
 			filtered_stats = data_stats.split(best_attr_eval);
@@ -110,9 +115,12 @@
 				majorityNode.setNumMatch(0);
 				majorityNode.setNumClassification(0);
 				//currentNode.setInfoMea(best_attr_eval.attribute_eval);
+				
+				majorityNode.setFather(currentNode);
 				currentNode.putNode(category, majorityNode);
 			} else {
 				TreeNode newNode = train(child_dt, filtered_stats.get(category));//, attributeNames_copy
+				newNode.setFather(currentNode);
 				currentNode.putNode(category, newNode);
 			}
 		}

Added: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java	2008-07-10 11:21:00 UTC (rev 20992)
@@ -0,0 +1,101 @@
+package org.drools.learner.eval;
+
+
+import java.util.ArrayList;
+
+import org.drools.learner.DecisionTree;
+import org.drools.learner.InstanceList;
+import org.drools.learner.Stats;
+import org.drools.learner.builder.Learner;
+import org.drools.learner.builder.SingleTreeTester;
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
+import org.drools.learner.tools.Util;
+
+public class CrossValidation implements Estimator{
+
+	private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(CrossValidation.class, SimpleLogger.DEFAULT_LEVEL);
+	private static SimpleLogger slog = LoggerFactory.getSysOutLogger(CrossValidation.class, SimpleLogger.DEFAULT_LEVEL);
+	
+	private int k_fold;
+	private double error_estimate;
+	private ArrayList<DecisionTree> forest;
+	
+	private boolean WITH_REP = false;
+
+	/** @param _k number of folds; must be at least 1, otherwise IllegalArgumentException is thrown */
+	public CrossValidation(int _k) {
+		if (_k < 1) {
+			throw new IllegalArgumentException("number of folds must be at least 1: " + _k);
+		}
+		k_fold = _k;
+		forest = new ArrayList<DecisionTree> (k_fold);
+		error_estimate = 0.0d;
+	}
+	
+	/**
+	 * Trains one tree per fold on the instances outside the fold and counts the
+	 * fold's instances the tree gets wrong (meant for small samples). Fold f is
+	 * the bag positions [f*fold_size, (f+1)*fold_size); positions after the last
+	 * fold are only used for learning. Each tree's id is its 0-based fold index.
+	 *
+	 * @throws UnsupportedOperationException if there is more than one target
+	 */
+	public void validate(InstanceList class_instances, Learner _trainer) {		
+		if (class_instances.getTargets().size()>1 ) {
+			if (flog.error() !=null)
+				flog.error().log("There is more than 1 target candidates\n");
+			// TODO implement it; until then fail loudly instead of exiting the JVM with status 0
+			throw new UnsupportedOperationException("There is more than 1 target candidates");
+		}
+	
+		int N = class_instances.getSize();
+		int[] bag = WITH_REP ? Util.bag_w_rep(N, N) : Util.bag_wo_rep(N, N);
+		int fold_size = N/k_fold;
+
+		for (int fold = 0; fold < k_fold; fold++) {
+			int fold_start = fold_size*fold;
+			int fold_end = fold_start + fold_size;	// exclusive
+			InstanceList learning_set = new InstanceList(class_instances.getSchema(), N- fold_size);
+			InstanceList validation_set = new InstanceList(class_instances.getSchema(), fold_size);
+			for (int divide_index = 0; divide_index <N; divide_index++){
+				if (divide_index >= fold_start && divide_index < fold_end) { // validation
+					validation_set.addAsInstance(class_instances.getInstance(bag[divide_index]));
+				} else { // learning part
+					learning_set.addAsInstance(class_instances.getInstance(bag[divide_index]));
+				}
+			}
+						
+			DecisionTree dt = _trainer.train_tree(learning_set);
+			dt.setID(fold);
+			forest.add(dt);
+			
+			int error = 0;
+			SingleTreeTester t= new SingleTreeTester(dt);
+			for (int index_i = 0; index_i < fold_size; index_i++) {
+				Integer result = t.test(validation_set.getInstance(index_i));
+				if (result == Stats.INCORRECT) {
+					error ++;
+				}
+			}
+			dt.setValidationError(error);
+			error_estimate += (double) error/k_fold;	// real division: averages the per-fold error counts
+
+			if (slog.stat() !=null)
+				slog.stat().stat(".");
+		}
+		// TODO how to compute a best tree from the forest
+	}
+	
+	public double getErrorEstimate() {
+		return error_estimate;
+	}
+	
+	public ArrayList<DecisionTree> getEstimators() {
+		return forest;
+	}
+	
+	public int getEstimatorSize() {
+		return k_fold;
+	}
+}


Property changes on: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
___________________________________________________________________
Name: svn:eol-style
   + native

Added: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java	2008-07-10 11:21:00 UTC (rev 20992)
@@ -0,0 +1,11 @@
+package org.drools.learner.eval;
+
+import java.util.ArrayList;
+
+import org.drools.learner.DecisionTree;
+
+public interface Estimator {
+	// Members are implicitly public, as in every interface.
+	int getEstimatorSize();                     // number of trees this estimator provides
+	ArrayList<DecisionTree> getEstimators();    // the trees themselves
+}


Property changes on: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
___________________________________________________________________
Name: svn:eol-style
   + native




More information about the jboss-svn-commits mailing list