[jboss-svn-commits] JBL Code SVN: r21422 - in labs/jbossrules/contrib/machinelearning/5.0: drools-core/src/main/java/org/drools/learner/builder and 4 other directories.

jboss-svn-commits at lists.jboss.org jboss-svn-commits at lists.jboss.org
Sun Aug 10 21:52:30 EDT 2008


Author: gizil
Date: 2008-08-10 21:52:30 -0400 (Sun, 10 Aug 2008)
New Revision: 21422

Added:
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ErrorEstimate.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/EstimatedNodeSize.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/ImpurityDecrease.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/StoppingCriterion.java
Modified:
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
Log:
pruner bug fixes

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java	2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -198,7 +198,7 @@
 			
 			TreeNode child = my_node.getChild(child_key);
 			if (child instanceof LeafNode) {
-				terminals.add((LeafNode)my_node);
+				terminals.add((LeafNode)child);
 				if (!anchestor_added) {
 					num_nonterminal_nodes ++; // TODO does this really work?
 					anchestors.add(my_node);

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java	2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -1,10 +1,9 @@
 package org.drools.learner;
 
 import java.util.ArrayList;
-import java.util.Collections;
 
 import org.drools.learner.builder.SingleTreeTester;
-import org.drools.learner.eval.Estimator;
+import org.drools.learner.eval.ErrorEstimate;
 import org.drools.learner.tools.LoggerFactory;
 import org.drools.learner.tools.SimpleLogger;
 import org.drools.learner.tools.Util;
@@ -16,15 +15,16 @@
 	private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(DecisionTreePruner.class, SimpleLogger.DEFAULT_LEVEL);
 	private static SimpleLogger slog = LoggerFactory.getSysOutLogger(DecisionTreePruner.class, SimpleLogger.DEBUG);
 	
-	private Estimator procedure;
+	private ErrorEstimate procedure;
 	
 	private TreeStats best_stats;
 
 	private int num_trees_to_grow;
 	
 	private double INIT_ALPHA = 0.5d;
+	private static final double EPSILON = 0.0d;//0.0000000001;
 	
-	public DecisionTreePruner(Estimator proc) {
+	public DecisionTreePruner(ErrorEstimate proc) {
 		procedure = proc;
 		num_trees_to_grow = procedure.getEstimatorSize();
 		//updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
@@ -43,15 +43,20 @@
 		 * which is done at the second stage.
 		 */
 		double value_to_select = procedure.getErrorEstimate();
-//		private NodeUpdate best_update;
-		for (DecisionTree dt: procedure.getEstimators()) {
+
+		for (int dt_i = 0; dt_i<procedure.getEstimatorSize(); dt_i++) {
+			DecisionTree dt= procedure.getEstimator(dt_i);
+			
+			
 			// dt.getId()
-			//dt.calc_num_node_leaves(dt.getRoot()); // this is done in the estimator
+			// dt.calc_num_node_leaves(dt.getRoot()); // this is done in the estimator
 			
-			MinAlphaProc alpha_proc = new MinAlphaProc(INIT_ALPHA);
+			double epsilon = EPSILON * numExtraMisClassIfPrun(dt.getRoot());
+			MinAlphaProc alpha_proc = new MinAlphaProc(INIT_ALPHA, epsilon);
 			TreeSequenceProc search = new TreeSequenceProc(dt, alpha_proc);//INIT_ALPHA
 			 
-			search.iterate_trees(0);
+			search.init_tree();	// alpha_1 = 0.0 
+			search.iterate_trees(1);
 			
 			//updates.add(tree_sequence);
 			updates.add(search.getTreeSequence());
@@ -60,25 +65,25 @@
 			
 			// sort the found candidates
 			//Collections.sort(updates.get(dt.getId()), arg1)
-			
-			int id =0;
+			int sid = 0, best_id = 0;
+			double best_error = 1.0; 
 			System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
-			for (NodeUpdate nu: updates.get(dt.getId()) ){
-				//TODO What to print here?
-				//System.out.println(id +"\t"+ nu.getNum_terminal_nodes()+"\t"+nu.getCross_validated_cost()+"\t"+nu.getResubstitution_cost()+"\t"+nu.getAlpha()+"\n");
-				id++;
+			ArrayList<TreeStats> trees = sequence_stats.get(dt.getId());
+			for (TreeStats st: trees ){
+				if (st.getCostEstimation() <= best_error) {
+					best_error = st.getCostEstimation();
+					best_id = sid;
+				}
+				System.out.println(sid+ "\t " +st.getNum_terminal_nodes()+ "\t "+ st.getCostEstimation() + "\t "+  st.getResubstitution_cost() + "\t "+ st.getAlpha() );
+				sid++;				
 			}
-			
-			int sid =0;
-			System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
-			for (TreeStats st: sequence_stats.get(dt.getId()) ){
-				//System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
-				System.out.println(sid+ "" +st.getAlpha() +" "+ st.getTest_cost());
-				sid++;
-				
-			}
-			int x =0;
+			System.out.println("BEST "+best_id+ "\t " +trees.get(best_id).getNum_terminal_nodes()+ "\t "+ trees.get(best_id).getCostEstimation() + "\t "+  trees.get(best_id).getResubstitution_cost() + "\t "+ trees.get(best_id).getAlpha() );
+		
+			int N = procedure.getTestDataSize(dt_i);
+			double standart_error_estimate = Math.sqrt((trees.get(best_id).getCostEstimation() * (1- trees.get(best_id).getCostEstimation())/ (double)N));
 		}
+		
+		
 	
 	
 	}
@@ -94,7 +99,8 @@
 	}
 	
 	public void prun_tree(DecisionTree tree) {
-		TreeSequenceProc search = new TreeSequenceProc(tree, new AnAlphaProc(best_stats.getAlpha()));
+		double epsilon = 0.0000001 * numExtraMisClassIfPrun(tree.getRoot());
+		TreeSequenceProc search = new TreeSequenceProc(tree, new AnAlphaProc(best_stats.getAlpha(), epsilon));
 		search.iterate_trees(0);
 		//search.getTreeSequence()// to go back
 		
@@ -108,23 +114,30 @@
 		
 	}
 	
+	// returns the node missclassification cost
+	private int R(TreeNode t) {
+		return (int) t.getNumMatch() - t.getNumLabeled();
+	}
 	
 	private int numExtraMisClassIfPrun(TreeNode my_node) {
-		int num_misclassified = (int) my_node.getNumMatch() - my_node.getNumLabeled(); // needs to be cast because of
+		int num_misclassified = R(my_node); // needs to be cast because of
+
 		if (slog.debug() !=null)
 			slog.debug().log(":numExtraMisClassIfPrun:num_misclassified "+ num_misclassified);
 		
-		
-		
 		int kids_misclassified = 0;
 		for(Object key: my_node.getChildrenKeys()) {
 			TreeNode child = my_node.getChild(key);
-			kids_misclassified += child.getMissClassified();//(int) child.getNumMatch() - child.getNumLabeled();
+			kids_misclassified += child.getMissClassified();
 			if (slog.debug() !=null)
 				slog.debug().log(" kid="+ kids_misclassified );
 		}
 		if (slog.debug() !=null)
 			slog.debug().log("\n");
+		if (num_misclassified < kids_misclassified) {
+			System.out.println("Problem ++++++");
+			System.exit(0);
+		}
 		return num_misclassified - kids_misclassified;
 	}
 	
@@ -132,6 +145,7 @@
 	public class TreeSequenceProc {
 		
 		private static final double MAX_ERROR_RATIO = 0.99;
+	
 		private DecisionTree focus_tree;
 		//private double the_alpha;
 		private AlphaSelectionProc alpha_proc;
@@ -141,7 +155,8 @@
 		private TreeStats best_tree_stats;
 		public TreeSequenceProc(DecisionTree dt, AlphaSelectionProc cond) { //, double init_alpha
 			focus_tree = dt;
-			//the_alpha = init_alpha;
+			
+			
 			alpha_proc = cond;
 			tree_sequence = new ArrayList<NodeUpdate>();
 			tree_sequence_stats = new ArrayList<TreeStats>();
@@ -169,6 +184,38 @@
 			return tree_sequence;
 		}
 
+		private void init_tree() {
+			// initialize the tree to be prunned
+			// T_1 is the smallest subtree of T_max satisfying R(T_1) = R(T_max)
+			
+			ArrayList<TreeNode> last_nonterminals = focus_tree.getAnchestor_of_Leaves(focus_tree.getRoot());
+			
+			// R(t) <= sum R(t_children) by the proposition 2.14, R(t) node miss-classification cost
+			// if R(t) = sum R(t_children) then prune off t_children
+			
+			boolean tree_changed = false; // (k)
+			for (TreeNode t: last_nonterminals) {
+//				int R_children = 0;
+//				for (Object child_key: t.getChildrenKeys()) {
+//					TreeNode child = t.getChild(child_key);
+//					R_children += child.getMissClassified();
+//				}
+//				if (t.getMissClassified() == R_children) {
+				if (numExtraMisClassIfPrun(t) == 0) {
+					// prune off the candidate node
+					tree_changed = true;
+					prune_off(t, 0);
+				} 
+				
+			}
+			if (tree_changed) {
+				TreeStats stats = new TreeStats();
+				stats.iteration_id(0);
+				update_tree_stats(stats, 0.0d, 0); // error_estimation = stats.getCostEstimation() for the set (cross_validation or test error)
+//			tree_sequence_stats.add(stats);
+			}
+		}
+		
 		private void iterate_trees(int i) {
 			if (slog.debug() !=null)
 				slog.debug().log(focus_tree.toString() +"\n");
@@ -199,11 +246,207 @@
 				// instead of getting the first node, have to process all nodes and prune all
 				// write a method to prune all
 				//TreeNode best_node = candidate_nodes.get(0);
+				
+				
+				int change_in_training_misclass = 0; // (k)
+				for (TreeNode candidate_node:candidate_nodes) {
+					// prune off the candidate node
+					prune_off(candidate_node, i);
+					change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
+				
+				}
 				TreeStats stats = new TreeStats();
 				stats.iteration_id(i); 
+				update_tree_stats(stats, min_alpha, change_in_training_misclass); // error_estimation = stats.getCostEstimation() for the set (cross_validation or test error)
 				
+				if (slog.debug() !=null)
+					slog.debug().log(":sequence_trees:error "+ stats.getCostEstimation() +"<?"+ procedure.getErrorEstimate() * 1.6 +"\n");
+
+				if (stats.getCostEstimation() < MAX_ERROR_RATIO) { //procedure.getValidationErrorEstimate() * 1.6) {
+					// if the error of the tree is not that bad
+					
+					if (stats.getCostEstimation() < best_tree_stats.getCostEstimation()) {
+						best_tree_stats = stats;
+						if (slog.debug() !=null)
+							slog.debug().log(":sequence_trees:best node updated \n");
+			
+					}
+					//	TODO update alpha_proc by increasing the min_alpha				the_alpha += 10;	
+					alpha_proc.init_proc(alpha_proc.getAlpha() + INIT_ALPHA);
+					iterate_trees(i+1);
+				} else {
+					//TODO update.setStopTree();
+					return;
+				}
+			} else {
+				if (slog.debug() !=null)
+					slog.debug().log(":sequence_trees:no candidate node is found ???? \n");
+			}
+			
+		}
+
+		
+		private void prune_off(TreeNode candidate_node, int i) {	
+			
+			LeafNode best_clone = new LeafNode(focus_tree.getTargetDomain(), candidate_node.getLabel());
+			best_clone.setRank(	candidate_node.getRank());
+			best_clone.setNumMatch(candidate_node.getNumMatch());				//num of matching instances to the leaf node
+			best_clone.setNumClassification(candidate_node.getNumLabeled());	//num of (correctly) classified instances at the leaf node
+
+			NodeUpdate update = new NodeUpdate(candidate_node, best_clone);
+			//TODO 
+			update.iteration_id(i); 
+			
+			update.setDecisionTree(focus_tree);
+			//change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
+			int num_leaves = candidate_node.getNumLeaves();
+
+			TreeNode father_node = candidate_node.getFather();
+			if (father_node != null) {
+				for(Object key: father_node.getChildrenKeys()) {
+					if (father_node.getChild(key).equals(candidate_node)) {
+						father_node.putNode(key, best_clone);
+						break;
+					}
+				}
+				updateLeaves(father_node, -num_leaves+1);
+			} else {
+				// this node does not have any father node it is the root node of the tree
+				focus_tree.setRoot(best_clone);
+			}
+			
+			//updates.get(dt_0.getId()).add(update);
+			tree_sequence.add(update);
+			
+		}
+		
+		private void update_tree_stats(TreeStats stats, double computed_alpha, int change_in_training_error) {
+			ArrayList<InstanceList> sets = procedure.getSets(focus_tree.getId());
+			InstanceList learning_set = sets.get(0);
+			InstanceList validation_set = sets.get(1);
+			
+			int num_error = 0;
+			SingleTreeTester t= new SingleTreeTester(focus_tree);
+			
+			System.out.println(validation_set.getSize());
+			for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
+				Integer result = t.test(validation_set.getInstance(index_i));
+				if (result == Stats.INCORRECT) {
+					num_error ++;
+				}
+			}
+			double percent_error = Util.division(num_error, validation_set.getSize());
+//			int _num_error = 0;
+//			for (int index_i = 0; index_i < learning_set.getSize(); index_i++) {
+//				Integer result = t.test(learning_set.getInstance(index_i));
+//				if (result == Stats.INCORRECT) {
+//					_num_error ++;
+//				}
+//			}
+//			double _percent_error = Util.division(_num_error, learning_set.getSize());
+
+			int new_num_leaves = focus_tree.getRoot().getNumLeaves();
+			
+			double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_error, procedure.getTrainingDataSize(focus_tree.getId()));
+			focus_tree.setTrainingError(new_resubstitution_cost);
+			
+			double cost_complexity = new_resubstitution_cost + computed_alpha * (new_num_leaves);
+			
+			if (slog.debug() !=null)
+				slog.debug().log(":sequence_trees:cost_complexity of selected tree "+ cost_complexity +"\n");
+			
+			
+			stats.setAlpha(computed_alpha);
+			stats.setCostEstimation(percent_error);
+			//stats.setTrainingCost(_percent_error);
+			stats.setResubstitution_cost(new_resubstitution_cost);
+			// Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
+			stats.setCost_complexity(cost_complexity);
+			stats.setNum_terminal_nodes(new_num_leaves);
+			tree_sequence_stats.add(stats);
+		}
+
+		// memory optimized
+		public void find_candidate_nodes(TreeNode my_node, ArrayList<TreeNode> nodes) {
+			
+			if (my_node instanceof LeafNode) {
+				
+				//leaves.add((LeafNode) my_node);
+				return;
+			} else {
+				// if you prune that one k more instances are misclassified
+				
+				int k = numExtraMisClassIfPrun(my_node);
+				int num_leaves = my_node.getNumLeaves();
+				
+				if (k==0) {
+					if (slog.debug() !=null)
+						slog.debug().log(":search_alphas:k == 0\n" );
+					
+				}
+				double num_training_data = (double)procedure.getTrainingDataSize(focus_tree.getId());
+				double alpha = ( (double)k / num_training_data) /((double)(num_leaves-1));
+				if (slog.debug() !=null)
+					slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+alpha_proc.getAlpha()+ " k "+k+" num_leaves "+num_leaves+" all "+ procedure.getTrainingDataSize(focus_tree.getId()) +  "\n");
+				
+				//the_alpha = alpha_proc.check_node(alpha, the_alpha, my_node, nodes);
+				alpha_proc.check_node(alpha, my_node, nodes);
+				
+				for (Object attributeValue : my_node.getChildrenKeys()) {
+					TreeNode child = my_node.getChild(attributeValue);
+					find_candidate_nodes(child, nodes);
+					//nodes.pop();
+				}				
+			}
+			return;
+		}
+		
+		public double getTheAlpha() {
+			return alpha_proc.getAlpha();
+		}
+		
+		
+		
+		
+		private void iterate_trees_(int i) {
+			if (slog.debug() !=null)
+				slog.debug().log(focus_tree.toString() +"\n");
+			// if number of non-terminal nodes in the tree is more than 1 
+			// = if there exists at least one non-terminal node different than root
+			if (focus_tree.getNumNonTerminalNodes() < 1) {
+				if (slog.debug() !=null)
+					slog.debug().log(":sequence_trees:TERMINATE-There is no non-terminal nodes? " + focus_tree.getNumNonTerminalNodes() +"\n");
+				return; 
+			} else if (focus_tree.getNumNonTerminalNodes() == 1 && focus_tree.getRoot().getNumLeaves()<=1) {
+				if (slog.debug() !=null)
+					slog.debug().log(":sequence_trees:TERMINATE-There is only one node left which is root node " + focus_tree.getNumNonTerminalNodes()+ " and it has only one leaf (pruned)" +focus_tree.getRoot().getNumLeaves()+"\n");
+				return; 
+			}
+			// for each non-leaf subtree
+			ArrayList<TreeNode> candidate_nodes = new ArrayList<TreeNode>();
+			// to find all candidates with min_alpha value
+			//TreeSequenceProc search = new TreeSequenceProc(the_alpha, alpha_proc);//100000.0d, new MinAlphaProc()); 
+			
+			
+			find_candidate_nodes(focus_tree.getRoot(), candidate_nodes);
+			//double min_alpha = search.getTheAlpha();
+			double min_alpha = getTheAlpha();
+			System.out.println("!!!!!!!!!!! dt:"+focus_tree.getId()+" ite "+i+" alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
+			
+			if (candidate_nodes.size() >0) {
+				// The one or more subtrees with that value of will be replaced by leaves
+				// instead of getting the first node, have to process all nodes and prune all
+				// write a method to prune all
+				//TreeNode best_node = candidate_nodes.get(0);
+				TreeStats stats = new TreeStats();
+				stats.iteration_id(i); 
+				
 				int change_in_training_misclass = 0; // (k)
 				for (TreeNode candidate_node:candidate_nodes) {
+					// prune off the candidate node
+					//prune_off(candidate_node, i);
+					//change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
+/*	*/				
 					LeafNode best_clone = new LeafNode(focus_tree.getTargetDomain(), candidate_node.getLabel());
 					best_clone.setRank(	candidate_node.getRank());
 					best_clone.setNumMatch(candidate_node.getNumMatch());				//num of matching instances to the leaf node
@@ -213,9 +456,8 @@
 					update.iteration_id(i); 
 					
 					update.setDecisionTree(focus_tree);
-					change_in_training_misclass += numExtraMisClassIfPrun(candidate_node); // extra misclassified guys
+					
 					int num_leaves = candidate_node.getNumLeaves();
-
 					TreeNode father_node = candidate_node.getFather();
 					if (father_node != null) {
 						for(Object key: father_node.getChildrenKeys()) {
@@ -232,15 +474,17 @@
 					
 					//updates.get(dt_0.getId()).add(update);
 					tree_sequence.add(update);
-					
+/**/					
 				}
-				
-				ArrayList<InstanceList> sets = procedure.getFold(focus_tree.getId());
-				//InstanceList learning_set = sets.get(0);
+/**/				
+				ArrayList<InstanceList> sets = procedure.getSets(focus_tree.getId());
+				InstanceList learning_set = sets.get(0);
 				InstanceList validation_set = sets.get(1);
 				
 				int num_error = 0;
 				SingleTreeTester t= new SingleTreeTester(focus_tree);
+				
+				System.out.println(validation_set.getSize());
 				for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
 					Integer result = t.test(validation_set.getInstance(index_i));
 					if (result == Stats.INCORRECT) {
@@ -248,10 +492,18 @@
 					}
 				}
 				double percent_error = Util.division(num_error, validation_set.getSize());
+//				int _num_error = 0;
+//				for (int index_i = 0; index_i < learning_set.getSize(); index_i++) {
+//					Integer result = t.test(learning_set.getInstance(index_i));
+//					if (result == Stats.INCORRECT) {
+//						_num_error ++;
+//					}
+//				}
+//				double _percent_error = Util.division(_num_error, learning_set.getSize());
 
 				int new_num_leaves = focus_tree.getRoot().getNumLeaves();
 				
-				double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_misclass, procedure.getTrainingDataSize(focus_tree.getId())/*focus_tree.FACTS_READ*/);
+				double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_misclass, procedure.getTrainingDataSize(focus_tree.getId()));
 				double cost_complexity = new_resubstitution_cost + min_alpha * (new_num_leaves);
 				
 				
@@ -260,20 +512,22 @@
 				
 				
 				stats.setAlpha(min_alpha);
-				stats.setTest_cost(percent_error);
+				stats.setCostEstimation(percent_error);
+//				stats.setTrainingCost(_percent_error);
 				stats.setResubstitution_cost(new_resubstitution_cost);
 				// Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
 				stats.setCost_complexity(cost_complexity);
 				stats.setNum_terminal_nodes(new_num_leaves);
+				
+/**/
 				tree_sequence_stats.add(stats);
-				
 				if (slog.debug() !=null)
 					slog.debug().log(":sequence_trees:error "+ percent_error +"<?"+ procedure.getErrorEstimate() * 1.6 +"\n");
 
 				if (percent_error < MAX_ERROR_RATIO) { //procedure.getValidationErrorEstimate() * 1.6) {
 					// if the error of the tree is not that bad
 					
-					if (percent_error < best_tree_stats.getTest_cost()) {
+					if (percent_error < best_tree_stats.getCostEstimation()) {
 						best_tree_stats = stats;
 						if (slog.debug() !=null)
 							slog.debug().log(":sequence_trees:best node updated \n");
@@ -281,7 +535,7 @@
 					}
 					//	TODO update alpha_proc by increasing the min_alpha				the_alpha += 10;	
 					alpha_proc.init_proc(alpha_proc.getAlpha() + INIT_ALPHA);
-					iterate_trees(i+1);
+					iterate_trees_(i+1);
 				} else {
 					//TODO update.setStopTree();
 					return;
@@ -292,55 +546,8 @@
 			}
 			
 		}
-		
-		// memory optimized
-		public void find_candidate_nodes(TreeNode my_node, ArrayList<TreeNode> nodes) {
-			
-			if (my_node instanceof LeafNode) {
-				
-				//leaves.add((LeafNode) my_node);
-				return;
-			} else {
-				// if you prune that one k more instances are misclassified
-				
-				int k = numExtraMisClassIfPrun(my_node);
-				int num_leaves = my_node.getNumLeaves();
-				
-				if (k==0) {
-					if (slog.debug() !=null)
-						slog.debug().log(":search_alphas:k == 0\n" );
-					
-				}
-				double alpha = ((double)k)/((double)procedure.getTrainingDataSize(focus_tree.getId())/*focus_tree.FACTS_READ*/ * (num_leaves-1));
-				if (slog.debug() !=null)
-					slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+alpha_proc.getAlpha()+ " k "+k+" num_leaves "+num_leaves+" all "+ procedure.getTrainingDataSize(focus_tree.getId()) +  "\n");
-				
-				//the_alpha = alpha_proc.check_node(alpha, the_alpha, my_node, nodes);
-				alpha_proc.check_node(alpha, my_node, nodes);
-				
-				for (Object attributeValue : my_node.getChildrenKeys()) {
-					TreeNode child = my_node.getChild(attributeValue);
-					find_candidate_nodes(child, nodes);
-					//nodes.pop();
-				}				
-			}
-			return;
-		}
-		
-		public double getTheAlpha() {
-			return alpha_proc.getAlpha();
-		}
 	}
 	
-	public boolean isChildOf(TreeNode cur_node, TreeNode found_node) {
-		TreeNode parent = cur_node.getFather();
-		if (parent == null)
-			return false;
-		if (parent.equals(found_node))
-			return true;
-		else 
-			return isChildOf(parent, found_node);
-	}
 	
 	public interface AlphaSelectionProc {
 		public double check_node(double cur_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes);
@@ -350,15 +557,20 @@
 	
 	public class AnAlphaProc implements AlphaSelectionProc{
 		private double an_alpha;
-		public AnAlphaProc(double value) {
+		private double CART_EPSILON;
+		public AnAlphaProc(double value, double epsilon) {
 			an_alpha = value;
+			CART_EPSILON = epsilon;
 		}
+//		public AnAlphaProc(double value) {
+//			an_alpha = value;
+//		}
 		public void init_proc(double value) {
 			// TODO ????
 		}
 		
-		public double check_node(double cur_alpha,TreeNode cur_node, ArrayList<TreeNode> nodes) {
-			if (Util.epsilon(cur_alpha - an_alpha)) {
+		public double check_node(double cur_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes) {
+			if (Util.epsilon(cur_alpha - an_alpha, CART_EPSILON)) {
 				for(TreeNode parent:nodes) {
 					if (isChildOf(cur_node, parent))
 						return cur_alpha;// it is not added
@@ -377,10 +589,13 @@
 	public class MinAlphaProc implements AlphaSelectionProc{
 		private double sum_min_alpha, init_min;
 		private int num_minimum;
-		public MinAlphaProc(double value) {
+		
+		private double CART_EPSILON;
+		public MinAlphaProc(double value, double epsilon) {
 			init_min = value;
 			sum_min_alpha = 0;
 			num_minimum = 0;
+			CART_EPSILON = epsilon;
 		}
 		
 		public void init_proc(double value) {
@@ -394,7 +609,7 @@
 			if (slog.debug() !=null)
 				slog.debug().log(":search_alphas:alpha "+ cur_alpha+ "/"+average_of_min_alphas+ " diff "+(cur_alpha - average_of_min_alphas)+"\n");
 			
-			if (Util.epsilon(cur_alpha - average_of_min_alphas)) {
+			if (Util.epsilon(cur_alpha - average_of_min_alphas, CART_EPSILON)) {
 				// check if the cur_node is a child of any node that has been added to the list before.
 				for(TreeNode parent:nodes) {
 					if (isChildOf(cur_node, parent))
@@ -428,7 +643,18 @@
 				return sum_min_alpha/num_minimum;
 		}
 	}
+	public boolean isChildOf(TreeNode cur_node, TreeNode found_node) {
+		TreeNode parent = cur_node.getFather();
+		if (parent == null)
+			return false;
+		if (parent.equals(found_node))
+			return true;
+		else 
+			return isChildOf(parent, found_node);
+	}
 	
+	
+	
 	public class NodeUpdate{
 		
 		private boolean stopTree;
@@ -438,12 +664,6 @@
 		private TreeNode old_node, node_update;
 		
 		private int iteration_id;
-		//public TreeStats stats;
-//		private int num_terminal_nodes;
-//		private double cross_validated_cost;
-//		private double resubstitution_cost;
-//		private double cost_complexity;
-//		private double alpha;
 		
 		// to set an node update with the worst cross validated error
 		public NodeUpdate(double error) {
@@ -481,10 +701,12 @@
 		private double resubstitution_cost;
 		private double cost_complexity;
 		private double alpha;
+//		private double training_error;
 		
 		public TreeStats() {
 			iteration_id = 0;
 		}
+
 		// to set an node update with the worst cross validated error
 		public TreeStats(double error) {
 			iteration_id = 0;
@@ -502,11 +724,11 @@
 			this.num_terminal_nodes = num_terminal_nodes;
 		}
 
-		public double getTest_cost() {
+		public double getCostEstimation() {
 			return test_cost;
 		}
 
-		public void setTest_cost(double valid_cost) {
+		public void setCostEstimation(double valid_cost) {
 			this.test_cost = valid_cost;
 		}
 
@@ -534,6 +756,14 @@
 			this.alpha = alpha;
 		}
 		
+//		public void setTrainingCost(double _percent_error) {
+//			training_error = _percent_error;
+//		}
+//		
+//		public double getTrainingCost() {
+//			return training_error;
+//		}
+		
 	}
 	
 

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java	2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -12,7 +12,7 @@
 import org.drools.learner.eval.Heuristic;
 import org.drools.learner.eval.InformationContainer;
 import org.drools.learner.eval.InstDistribution;
-import org.drools.learner.eval.StoppingCriterion;
+import org.drools.learner.eval.stopping.StoppingCriterion;
 import org.drools.learner.tools.FeatureNotSupported;
 import org.drools.learner.tools.Util;
 

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java	2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -11,10 +11,12 @@
 import org.drools.learner.builder.DecisionTreeBuilder.TreeAlgo;
 import org.drools.learner.eval.CrossValidation;
 import org.drools.learner.eval.Entropy;
-import org.drools.learner.eval.EstimatedNodeSize;
+import org.drools.learner.eval.ErrorEstimate;
 import org.drools.learner.eval.GainRatio;
 import org.drools.learner.eval.Heuristic;
-import org.drools.learner.eval.StoppingCriterion;
+import org.drools.learner.eval.TestSample;
+import org.drools.learner.eval.stopping.EstimatedNodeSize;
+import org.drools.learner.eval.stopping.StoppingCriterion;
 import org.drools.learner.tools.FeatureNotSupported;
 import org.drools.learner.tools.Util;
 
@@ -152,60 +154,6 @@
 		return learner.getTree();
 	}
 	
-	public static DecisionTree createSinglePrunnedC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-		return createSinglePrunnedC45(wm, obj_class, new Entropy());
-	}
-	public static DecisionTree createSinglePrunnedC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-		return createSinglePrunnedC45(wm, obj_class, new GainRatio());
-	}
-	
-	protected static DecisionTree createSinglePrunnedC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
-		DataType data = Learner.DEFAULT_DATA;
-		C45Learner learner = new C45Learner(h);
-		
-		SingleTreeBuilder single_builder = new SingleTreeBuilder();
-		
-//		String algo_suffices = org.drools.learner.deprecated.DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
-//		String executionSignature = org.drools.learner.deprecated.DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-		String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
-		String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
-		
-		
-		
-		/* create the memory */
-		Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
-		single_builder.build(mem, learner);//obj_class, target_attr, working_attr
-		
-		CrossValidation validater = new CrossValidation(10, mem.getClassInstances());
-		validater.validate(learner);
-		
-		DecisionTreePruner pruner = new DecisionTreePruner(validater);
-		pruner.prun_to_estimate();
-		
-		// you should be able to get the pruned tree
-		// prun.getMinimumCostTree()
-		// prun.getOptimumCostTree()
-		
-		// test the tree
-		SingleTreeTester tester = new SingleTreeTester(learner.getTree());
-		tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature);
-		
-		/* Once Talpha is found the tree that is finally suggested for use is that 
-		 * which minimises the cost-complexity using and all the data use the pruner to prun the tree
-		 */
-		pruner.prun_tree(learner.getTree());
-		
-		
-		// test the tree again
-		
-		
-		//Tester.test(c45, mem.getClassInstances());
-		
-		learner.getTree().setSignature(executionSignature);
-		return learner.getTree();
-	}
-	
-	
 	public static DecisionTree createSingleC45E_StoppingCriteria(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
 		return createSingleC45_Stop(wm, obj_class, new Entropy());
 	}
@@ -238,22 +186,51 @@
 		return learner.getTree();
 	}
 	
-	
-	public static DecisionTree createSinglePrunnedStopC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-		return createSinglePrunnedStopC45(wm, obj_class, new Entropy());
+	public static DecisionTree createSingleCrossPrunnedStopC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+		
+		ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+		stopping_criteria.add(new EstimatedNodeSize(0.05));
+		return createSinglePrunnedC45(wm, obj_class, new Entropy(), new CrossValidation(10), stopping_criteria);
 	}
-	public static DecisionTree createSinglePrunnedStopC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
-		return createSinglePrunnedStopC45(wm, obj_class, new GainRatio());
+	public static DecisionTree createSingleCrossPrunnedStopC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+		ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+		stopping_criteria.add(new EstimatedNodeSize(0.05));
+		return createSinglePrunnedC45(wm, obj_class, new GainRatio(), new CrossValidation(10), stopping_criteria);
 	}
 	
-	protected static DecisionTree createSinglePrunnedStopC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
-		DataType data = Learner.DEFAULT_DATA;
+	public static DecisionTree createSingleTestPrunnedStopC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+		
 		ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
 		stopping_criteria.add(new EstimatedNodeSize(0.05));
+		return createSinglePrunnedC45(wm, obj_class, new Entropy(), new TestSample(0.2d), stopping_criteria);
+	}
+	public static DecisionTree createSingleTestPrunnedStopC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+		ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+		stopping_criteria.add(new EstimatedNodeSize(0.05));
+		return createSinglePrunnedC45(wm, obj_class, new GainRatio(), new TestSample(0.2), stopping_criteria);
+	}
+	
+	public static DecisionTree createSingleCVPrunnedC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+		
+		ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+		//stopping_criteria.add(new EstimatedNodeSize(0.05));
+		return createSinglePrunnedC45(wm, obj_class, new Entropy(), new CrossValidation(10), stopping_criteria);
+	}
+	public static DecisionTree createSingleCVPrunnedC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+		ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+		//stopping_criteria.add(new EstimatedNodeSize(0.05));
+		return createSinglePrunnedC45(wm, obj_class, new GainRatio(), new CrossValidation(10), stopping_criteria);
+	}
+	
+	protected static DecisionTree createSinglePrunnedC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h, ErrorEstimate validater, ArrayList<StoppingCriterion> stopping_criteria) throws FeatureNotSupported {
+		DataType data = Learner.DEFAULT_DATA;
+
 		C45Learner learner = new C45Learner(h, stopping_criteria);
 		
 		SingleTreeBuilder single_builder = new SingleTreeBuilder();
-	
+		
+//		String algo_suffices = org.drools.learner.deprecated.DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
+//		String executionSignature = org.drools.learner.deprecated.DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
 		String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
 		String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
 		
@@ -261,9 +238,9 @@
 		Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
 		single_builder.build(mem, learner);//obj_class, target_attr, working_attr
 		
-		CrossValidation validater = new CrossValidation(10, mem.getClassInstances());
-		validater.validate(learner);
 		
+		validater.validate(learner, mem.getClassInstances());
+		
 		DecisionTreePruner pruner = new DecisionTreePruner(validater);
 		pruner.prun_to_estimate();
 		

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java	2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -12,7 +12,7 @@
 import org.drools.learner.tools.SimpleLogger;
 import org.drools.learner.tools.Util;
 
-public class CrossValidation implements Estimator{
+public class CrossValidation implements ErrorEstimate{
 
 	private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(CrossValidation.class, SimpleLogger.DEFAULT_LEVEL);
 	private static SimpleLogger slog = LoggerFactory.getSysOutLogger(CrossValidation.class, SimpleLogger.DEBUG);
@@ -33,7 +33,9 @@
 	private int VALID_SET_0 = 0, VALID_SET_1 = 1;
 	
 	private boolean WITH_REP = false;
-	public CrossValidation(int _k, InstanceList _instances) {
+
+	
+	public CrossValidation(int _k) {
 		if (_k <=1) {
 			if (flog.warn() !=null)
 				flog.warn().log("There is 1 or less number of folds specified, i am setting "+MIN_NUM_FOLDS+" folds\n");
@@ -46,13 +48,10 @@
 		training_error_estimate = 0.0d;
 		num_leaves_estimate = 0.0d;
 		alpha_estimate = 0.0d;
-		class_instances = _instances;
 		
-		num_instances = class_instances.getSize();
 		fold_indices = new int [k_fold] [2];
 	}
 	public int [] cross_set(int N) {
-
 		if (WITH_REP)
 			return Util.bag_w_rep(N, N);
 		else
@@ -61,7 +60,9 @@
 	
 	
 	// for small samples
-	public void validate(Learner _trainer) {		
+	public void validate(Learner _trainer, InstanceList _instances) {		
+		class_instances = _instances;
+		num_instances = class_instances.getSize();
 		if (class_instances.getTargets().size()>1 ) {
 			//throw new FeatureNotSupported("There is more than 1 target candidates");
 			if (flog.error() !=null)
@@ -69,7 +70,7 @@
 			System.exit(0);
 			// TODO put the feature not supported exception || implement it
 		}
-
+	
 		crossed_set = cross_set(num_instances);
 		
 		// N-fold 
@@ -78,10 +79,10 @@
 		int i = 0;
 //			int[] bag;		
 		while (i++ < k_fold ) {
-			int fold_size = getFoldSize(i);
+			int fold_size = getTestDataSize(i);
 			if (slog.debug() !=null)
 				slog.debug().log("i "+(i-1)+"/"+k_fold);
-			ArrayList<InstanceList> sets = getFold(i-1);
+			ArrayList<InstanceList> sets = getSets(i-1);
 			InstanceList learning_set = sets.get(0);
 			InstanceList validation_set = sets.get(1);
 						
@@ -132,7 +133,7 @@
 		while (fold_index < k_fold) {
 			
 			fold_indices[fold_index][VALID_SET_0] = divide_index;
-			int fold_size = getFoldSize(fold_index); 
+			int fold_size = getTestDataSize(fold_index); 
 			fold_indices[fold_index][VALID_SET_1] = fold_indices[fold_index][VALID_SET_0] +  fold_size -1;
 			
 			System.out.println(fold_indices[fold_index][VALID_SET_0] +" - "+ fold_indices[fold_index][VALID_SET_1]);
@@ -141,11 +142,11 @@
 		}
 		
 	}
-	public ArrayList<InstanceList> getFold(int i) {
+	public ArrayList<InstanceList> getSets(int i) {
 //		// first part divide = 0; divide < fold_size*i
 //		// the validation set divide = fold_size*i; divide < fold_size*(i+1)-1
 //		// last part divide = fold_size*(i+1); divide < N
-		int valid_fold_size = getFoldSize(i);
+		int valid_fold_size = getTestDataSize(i);
 		InstanceList learning_set = new InstanceList(class_instances.getSchema(), num_instances - valid_fold_size +1);
 		InstanceList validation_set = new InstanceList(class_instances.getSchema(), valid_fold_size);
 		for (int divide_index = 0; divide_index < num_instances; divide_index++){
@@ -169,36 +170,15 @@
 		
 	}
 	
-//	public ArrayList<InstanceList> getFolds() {
-////		// first part divide = 0; divide < fold_size*i
-////		// the validation set divide = fold_size*i; divide < fold_size*(i+1)-1
-////		// last part divide = fold_size*(i+1); divide < N
-//		int fold_size = getFoldSize();
-//		InstanceList learning_set = new InstanceList(class_instances.getSchema(), num_instances - fold_size +1);
-//		InstanceList validation_set = new InstanceList(class_instances.getSchema(), fold_size);
-//		for (int divide_index = 0; divide_index < num_instances; divide_index++){
-//			if (divide_index >= fold_size*i && divide_index < fold_size*(i+1)-1) { // validation
-//				validation_set.addAsInstance(class_instances.getInstance(crossed_set[divide_index]));
-//			} else { // learninf part 
-//				learning_set.addAsInstance(class_instances.getInstance(crossed_set[divide_index]));
-//			}
-//		}
-//		
-//		ArrayList<InstanceList> lists = new ArrayList<InstanceList>(2);
-//		lists.add(learning_set);
-//		lists.add(validation_set);
-//		return lists;
-//		
-//	}
 	
 	public int getTrainingDataSize(int i) {
-		return num_instances-getFoldSize(i);
+		return num_instances-getTestDataSize(i);
 	}
 	
 	public double getAlphaEstimate() {
 		return alpha_estimate;
 	}
-	private int getFoldSize(int i) {
+	public int getTestDataSize(int i) {
 		int excess = num_instances % k_fold;
 		return (int) num_instances/k_fold + (i < excess? 1:0);
 	}
@@ -206,8 +186,8 @@
 		return validation_error_estimate;
 	}
 	
-	public ArrayList<DecisionTree> getEstimators() {
-		return forest;
+	public DecisionTree getEstimator(int i) {
+		return forest.get(i);
 	}
 	
 	public int getEstimatorSize() {

Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ErrorEstimate.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ErrorEstimate.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ErrorEstimate.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,21 @@
+package org.drools.learner.eval;
+
+import java.util.ArrayList;
+
+import org.drools.learner.DecisionTree;
+import org.drools.learner.InstanceList;
+import org.drools.learner.builder.Learner;
+
/**
 * Contract for procedures that estimate the generalization error of a
 * {@link Learner} — e.g. k-fold cross validation or a single held-out
 * test sample.  An implementation trains one or more estimator trees on
 * partitions of the data and exposes the resulting error statistics.
 */
public interface ErrorEstimate {

	/** Trains estimator tree(s) on partitions of _instances and computes the error/alpha estimates. */
	public void validate(Learner _trainer, InstanceList _instances);
	
	/** Number of estimator trees produced by validate() (k for k-fold, 1 for a test sample). */
	public int getEstimatorSize();
	/** The i-th estimator tree produced by validate(). */
	public DecisionTree getEstimator(int i);
	/** Partition id's data as [learning set, validation/test set] (index 0 = learning, 1 = validation). */
	public ArrayList<InstanceList> getSets(int id);
	/** Size of the i-th validation/test partition. */
	public int getTestDataSize(int i);
	/** Size of the i-th learning partition. */
	public int getTrainingDataSize(int i);
	
	/** Estimated validation error (averaged over estimators where applicable). */
	public double getErrorEstimate();
	/** Estimated alpha = (validation error - training error) / number of leaves. */
	public double getAlphaEstimate();
}


Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ErrorEstimate.java
___________________________________________________________________
Name: svn:eol-style
   + native

Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,152 @@
+package org.drools.learner.eval;
+
+import java.util.ArrayList;
+
+import org.drools.learner.DecisionTree;
+import org.drools.learner.InstanceList;
+import org.drools.learner.Stats;
+import org.drools.learner.builder.Learner;
+import org.drools.learner.builder.SingleTreeTester;
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
+import org.drools.learner.tools.Util;
+
+public class TestSample implements ErrorEstimate{
+	
+	private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(TestSample.class, SimpleLogger.DEFAULT_LEVEL);
+	private static SimpleLogger slog = LoggerFactory.getSysOutLogger(TestSample.class, SimpleLogger.DEBUG);
+	
+	private int TEST_SET_0 = 0, TEST_SET_1 = 0;
+	
+	private InstanceList class_instances;
+	private double error_estimate, training_error_estimate, num_leaves_estimate, alpha_estimate;
+	private boolean WITH_REP = false;
+	private int num_instances;
+	private int[] crossed_set;
+	private double test_ratio;
+	
+	private DecisionTree dt;
+	
+	public TestSample(double _ratio) {
+		test_ratio = _ratio;
+		error_estimate = 0.0d;
+		training_error_estimate = 0.0d;
+		num_leaves_estimate = 0.0d;
+		alpha_estimate = 0.0d;
+		
+	}
+	
+	public int [] cross_set(int N) {
+		if (WITH_REP )
+			return Util.bag_w_rep(N, N);
+		else
+			return Util.bag_wo_rep(N, N);
+	}
+	public void validate(Learner _trainer, InstanceList _instances) {
+		class_instances = _instances;
+		num_instances = class_instances.getSize();
+		if (class_instances.getTargets().size()>1 ) {
+			//throw new FeatureNotSupported("There is more than 1 target candidates");
+			if (flog.error() !=null)
+				flog.error().log("There is more than 1 target candidates\n");
+			System.exit(0);
+			// TODO put the feature not supported exception || implement it
+		}
+		
+		TEST_SET_1 = (int)(test_ratio * num_instances)+1;
+		System.out.println(test_ratio +"*"+ num_instances+" "+TEST_SET_1);
+		
+		crossed_set = cross_set(num_instances);
+		
+		ArrayList<InstanceList> sets = getSets(0);
+		InstanceList learning_set = sets.get(0);
+		InstanceList test_set = sets.get(1);
+					
+		dt = _trainer.train_tree(learning_set);
+		dt.setID(0);
+
+		
+		int error = 0;
+		SingleTreeTester t= new SingleTreeTester(dt);
+		if (slog.debug() !=null)
+			slog.debug().log("validation fold_size " +test_set.getSize() + "\n");
+		for (int index_i = 0; index_i < test_set.getSize(); index_i++) {
+			
+			if (slog.warn() !=null)
+				slog.warn().log(" validation index_i " +index_i + (index_i ==test_set.getSize() -1?"\n":""));
+			Integer result = t.test(test_set.getInstance(index_i));
+			if (result == Stats.INCORRECT) {
+				error ++;
+			}
+		}
+		dt.setValidationError(Util.division(error, test_set.getSize()));
+		dt.calc_num_node_leaves(dt.getRoot());
+		
+		if (slog.error() !=null)
+			slog.error().log("The estimate of : "+(0)+" training=" +dt.getTrainingError() +" valid=" + dt.getValidationError() +" num_leaves=" + dt.getRoot().getNumLeaves()+"\n");
+	
+		/* moving averages */
+		error_estimate = dt.getValidationError();
+		training_error_estimate = (double)dt.getTrainingError();
+		num_leaves_estimate = (double)dt.getRoot().getNumLeaves();
+
+
+//		if (slog.stat() !=null)
+//			slog.stat().stat("."+ (i == k_fold?"\n":""));
+		alpha_estimate = (error_estimate - training_error_estimate) /num_leaves_estimate;
+		if (slog.stat() !=null)
+			slog.stat().log(" The estimates: training=" +training_error_estimate +" valid=" + error_estimate +" num_leaves=" + num_leaves_estimate+ " the alpha"+ alpha_estimate+"\n");
+		// TODO how to compute a best tree from the forest
+	}
+	
+	public ArrayList<InstanceList> getSets(int i) {
+//		// first part divide = 0; divide < fold_size*i
+//		// the validation set divide = fold_size*i; divide < fold_size*(i+1)-1
+//		// last part divide = fold_size*(i+1); divide < N
+		InstanceList learning_set = new InstanceList(class_instances.getSchema(), num_instances - TEST_SET_1 +1);
+		InstanceList validation_set = new InstanceList(class_instances.getSchema(), TEST_SET_1);
+		for (int divide_index = 0; divide_index < num_instances; divide_index++){
+			
+			if (slog.info() !=null)
+				slog.info().log("index " +divide_index+ " fold_size" + TEST_SET_1 + " = from "+TEST_SET_0+" to "+TEST_SET_1+" num_instances "+num_instances+ "\n");
+			if (divide_index >= TEST_SET_0 && divide_index <= TEST_SET_1) { // validation
+				// validation [fold_size*i, fold_size*(i+1))
+				if (slog.info() !=null)
+					slog.info().log("validation one " +divide_index+ "\n");
+				validation_set.addAsInstance(class_instances.getInstance(crossed_set[divide_index]));
+			} else { // learninf part 
+				learning_set.addAsInstance(class_instances.getInstance(crossed_set[divide_index]));
+			}
+		}
+		
+		ArrayList<InstanceList> lists = new ArrayList<InstanceList>(2);
+		lists.add(learning_set);
+		lists.add(validation_set);
+		return lists;
+		
+	}
+
+	public double getAlphaEstimate() {
+		return alpha_estimate;
+	}
+
+	public double getErrorEstimate() {
+		return error_estimate;
+	}
+
+	public int getEstimatorSize() {
+		return 1;
+	}
+
+	public DecisionTree getEstimator(int i) {
+		return dt;
+	}
+
+	public int getTrainingDataSize(int i) {
+		return num_instances - TEST_SET_1;
+	}
+	public int getTestDataSize(int i) {
+		return TEST_SET_1;
+	}
+
+}


Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/TestSample.java
___________________________________________________________________
Name: svn:eol-style
   + native

Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/EstimatedNodeSize.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/EstimatedNodeSize.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/EstimatedNodeSize.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,31 @@
+package org.drools.learner.eval.stopping;
+
+import org.drools.learner.eval.InformationContainer;
+
+public class EstimatedNodeSize implements StoppingCriterion {
+	
+	private double outlier_percentage;
+	private int estimated_sum_branch;
+	private int num_times_branched;
+	
+	
+	public EstimatedNodeSize(double o_p) {
+		outlier_percentage = o_p;
+		num_times_branched = 0;
+	}
+	
+
+	public boolean stop(InformationContainer best_attr_eval) {
+		int d = best_attr_eval.getDepth();
+		estimated_sum_branch += best_attr_eval.domain.getCategoryCount();
+		num_times_branched ++;
+		double estimated_branch = (double)estimated_sum_branch/(double)num_times_branched;
+		// N/(b^d)
+		double estimated_size = best_attr_eval.getTotalNumData()/Math.pow(estimated_branch, d);
+		System.out.println("EstimatedNodeSize:stop: " +best_attr_eval.getNumData() + " <= " + ( Math.ceil(estimated_size*outlier_percentage)-1) +" / "+estimated_size);
+		if (best_attr_eval.getNumData() <= Math.ceil(estimated_size*outlier_percentage)-1)
+			return true;
+		else 
+			return false;
+	}
+}


Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/EstimatedNodeSize.java
___________________________________________________________________
Name: svn:eol-style
   + native

Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/ImpurityDecrease.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/ImpurityDecrease.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/ImpurityDecrease.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,20 @@
+package org.drools.learner.eval.stopping;
+
+import org.drools.learner.eval.InformationContainer;
+
+
+public class ImpurityDecrease implements StoppingCriterion {
+
+	private double beta = 0.1;
+	
+	public ImpurityDecrease(double _beta) {
+		beta = _beta;
+	}
+	public boolean stop(InformationContainer best_attr_eval) {
+		if (best_attr_eval.attribute_eval < beta)
+			return true;
+		else
+			return false;
+	}
+
+}


Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/ImpurityDecrease.java
___________________________________________________________________
Name: svn:eol-style
   + native

Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,19 @@
+package org.drools.learner.eval.stopping;
+
+import org.drools.learner.eval.InformationContainer;
+
+
+public class MaximumDepth implements StoppingCriterion {
+
+	private int limit_depth;
+	public MaximumDepth(int _depth) {
+		limit_depth = _depth;
+	}
+	public boolean stop(InformationContainer best_attr_eval) {
+		if (best_attr_eval.getDepth() <= limit_depth)
+			return false;
+		else 
+			return true;
+	}
+
+}


Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/MaximumDepth.java
___________________________________________________________________
Name: svn:eol-style
   + native

Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/StoppingCriterion.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/StoppingCriterion.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/StoppingCriterion.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -0,0 +1,9 @@
+package org.drools.learner.eval.stopping;
+
+import org.drools.learner.eval.InformationContainer;
+
/**
 * A pre-pruning test consulted while a decision tree is being grown:
 * given the evaluation of the best split candidate at the current node,
 * decides whether tree construction should stop at that node.
 */
public interface StoppingCriterion {
	
	/** Returns true if growing should stop at the node described by best_attr_eval. */
	public boolean stop(InformationContainer best_attr_eval);

}


Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/stopping/StoppingCriterion.java
___________________________________________________________________
Name: svn:eol-style
   + native

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java	2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -66,8 +66,8 @@
 		return (double)x/(double)y;
 	}
 	
-	public static boolean epsilon(double d) {
-		return Math.abs(d) <= 0.0001;
+	public static boolean epsilon(double d, double epsilon) {
+		return Math.abs(d) <= epsilon; //0.000001;
 	}
 	
 	/* TODO make this all_fields arraylist as hashmap */

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java	2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -72,14 +72,17 @@
 //			decision_tree  = DecisionTreeFactory.createGlobal2(session, obj_class);
 //			break;
 		case 400: 
-			decision_tree = DecisionTreeFactory.createSinglePrunnedC45E(session, obj_class);
+			decision_tree = DecisionTreeFactory.createSingleCVPrunnedC45E(session, obj_class);
 			break;
 		case 500:
 			decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
 			break;
 		case 600:
-			decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+			decision_tree = DecisionTreeFactory.createSingleCrossPrunnedStopC45E(session, obj_class);
 			break;
+		case 601:
+			decision_tree = DecisionTreeFactory.createSingleTestPrunnedStopC45E(session, obj_class);
+			break;
 		default:
 			decision_tree  = DecisionTreeFactory.createSingleID3E(session, obj_class);
 		

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java	2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -75,13 +75,13 @@
 //			decision_tree  = DecisionTreeFactory.createGlobal2(session, obj_class);
 //			break;
 		case 400: 
-			decision_tree = DecisionTreeFactory.createSinglePrunnedC45E(session, obj_class);
+			decision_tree = DecisionTreeFactory.createSingleCVPrunnedC45E(session, obj_class);
 			break;
 		case 500:
 			decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
 			break;
 		case 600:
-			decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+			decision_tree = DecisionTreeFactory.createSingleCrossPrunnedStopC45E(session, obj_class);
 			break;
 		default:
 			decision_tree  = DecisionTreeFactory.createSingleID3E(session, obj_class);

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java	2008-08-11 01:01:31 UTC (rev 21421)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java	2008-08-11 01:52:30 UTC (rev 21422)
@@ -71,7 +71,7 @@
 			decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
 			break;
 		case 600:
-			decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+			decision_tree = DecisionTreeFactory.createSingleCrossPrunnedStopC45E(session, obj_class);
 			break;
 //			case 3:
 //			decision_tree  = DecisionTreeFactory.createGlobal2(session, obj_class);




More information about the jboss-svn-commits mailing list