[jboss-svn-commits] JBL Code SVN: r21321 - in labs/jbossrules/contrib/machinelearning/5.0: drools-core/src/main/java/org/drools/learner/builder and 3 other directories.

jboss-svn-commits at lists.jboss.org
Fri Aug 1 09:12:34 EDT 2008


Author: gizil
Date: 2008-08-01 09:12:34 -0400 (Fri, 01 Aug 2008)
New Revision: 21321

Added:
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/EstimatedNodeSize.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ImpurityDecrease.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/MaximumDepth.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/StoppingCriterion.java
Modified:
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/TreeNode.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
   labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
Log:
Add stopping criterion classes (StoppingCriterion, EstimatedNodeSize, ImpurityDecrease, MaximumDepth), thread node depth and separate per-tree/total training data sizes through the learners and builders, and rework alpha selection in DecisionTreePruner.
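
A minimal usage sketch of the new criteria, following the factory methods added to DecisionTreeFactory below (the working-memory and Memory setup is elided, and the MaximumDepth limit shown here is only an illustrative value):

    ArrayList<StoppingCriterion> criteria = new ArrayList<StoppingCriterion>();
    criteria.add(new EstimatedNodeSize(0.5)); // stop on nodes smaller than the estimated outlier share
    criteria.add(new MaximumDepth(10));       // stop once a branch exceeds the given depth
    C45Learner learner = new C45Learner(new Entropy(), criteria);
    // during train(), the learner now returns a majority leaf as soon as
    // any criterion's stop(best_attr_eval) returns true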

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTree.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -135,7 +135,7 @@
 
 			return 1;
 		} else {
-			num_nonterminal_nodes ++;
+			num_nonterminal_nodes ++; // TODO does this really work?
 		}
 		
 		int leaves = 0;
@@ -150,6 +150,69 @@
 		return leaves;
 	}
 	
+	//private ArrayList<LeafNode> leaf_nodes;
+	public ArrayList<LeafNode> getLeaves(TreeNode start_node) {
+		ArrayList<LeafNode> terminal_nodes = new ArrayList<LeafNode>();
+		
+		find_leaves(terminal_nodes, start_node);
+
+		return terminal_nodes;
+	}
+	
+	private int find_leaves(ArrayList<LeafNode> terminals, TreeNode my_node) {
+		if (my_node instanceof LeafNode) {
+			terminals.add((LeafNode)my_node);
+			return 1;
+		} else {
+			num_nonterminal_nodes ++; // TODO does this really work?
+		}
+		
+		int leaves = 0;
+		for (Object child_key: my_node.getChildrenKeys()) {
+			/* split the last two class at the same time */
+			
+			TreeNode child = my_node.getChild(child_key);
+			leaves += find_leaves(terminals, child);
+			
+		}
+		//my_node.setNumLeaves(leaves);
+		return leaves;
+	}
+	
+	public ArrayList<TreeNode> getAnchestor_of_Leaves(TreeNode start_node) {
+		ArrayList<LeafNode> terminal_nodes = new ArrayList<LeafNode>();
+		
+		ArrayList<TreeNode> anc_terminal_nodes = new ArrayList<TreeNode>();
+		
+		find_leaves(terminal_nodes, anc_terminal_nodes, start_node);
+
+		return anc_terminal_nodes;
+	}
+	
+	private int find_leaves(ArrayList<LeafNode> terminals, ArrayList<TreeNode> anchestors, TreeNode my_node) {
+		
+		int leaves = 0;
+		boolean anchestor_added = false;
+		for (Object child_key: my_node.getChildrenKeys()) {
+			/* split the last two class at the same time */
+			
+			TreeNode child = my_node.getChild(child_key);
+			if (child instanceof LeafNode) {
+				terminals.add((LeafNode)child);
+				if (!anchestor_added) {
+					num_nonterminal_nodes ++; // TODO does this really work?
+					anchestors.add(my_node);
+					anchestor_added = true;
+				}
+				return 1;
+			} else {
+				leaves += find_leaves(terminals, anchestors, child);
+			}
+		}
+		//my_node.setNumLeaves(leaves);
+		return leaves;
+	}
+	
 	public int getNumNonTerminalNodes() {
 		return num_nonterminal_nodes;
 	}

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/DecisionTreePruner.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -22,30 +22,41 @@
 
 	private int num_trees_to_grow;
 	
+	private double INIT_ALPHA = 0.5d;
+	
 	public DecisionTreePruner(Estimator proc) {
 		procedure = proc;
 		num_trees_to_grow = procedure.getEstimatorSize();
 		//updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
 		
-		best_stats = new TreeStats(proc.getAlphaEstimate());
+		best_stats = new TreeStats(0.0);//proc.getAlphaEstimate());
 	} 	
 	
 	public void prun_to_estimate() {	
 		ArrayList<ArrayList<NodeUpdate>> updates = new ArrayList<ArrayList<NodeUpdate>>(num_trees_to_grow);
-		
 		ArrayList<ArrayList<TreeStats>> sequence_stats = new ArrayList<ArrayList<TreeStats>>(num_trees_to_grow);
+		ArrayList<MinAlphaProc> alpha_procs = new ArrayList<MinAlphaProc>(num_trees_to_grow);
+		
+		/*
+		 * The best tree is selected from this series of trees with the classification error not exceeding 
+		 * 	an expected error rate on some test set (cross-validation error), 
+		 * which is done at the second stage.
+		 */
+		double value_to_select = procedure.getErrorEstimate();
 //		private NodeUpdate best_update;
 		for (DecisionTree dt: procedure.getEstimators()) {
 			// dt.getId()
 			//dt.calc_num_node_leaves(dt.getRoot()); // this is done in the estimator
 			
-			TreeSequenceProc search = new TreeSequenceProc(dt, 100000.0d, new MinAlphaProc());
+			MinAlphaProc alpha_proc = new MinAlphaProc(INIT_ALPHA);
+			TreeSequenceProc search = new TreeSequenceProc(dt, alpha_proc);//INIT_ALPHA
 			 
 			search.iterate_trees(0);
 			
 			//updates.add(tree_sequence);
 			updates.add(search.getTreeSequence());
 			sequence_stats.add(search.getTreeSequenceStats());
+			alpha_procs.add(alpha_proc);
 			
 			// sort the found candidates
 			//Collections.sort(updates.get(dt.getId()), arg1)
@@ -62,127 +73,32 @@
 			System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
 			for (TreeStats st: sequence_stats.get(dt.getId()) ){
 				//System.out.println("Tree id\t Num_leaves\t Cross-validated\t Resubstitution\t Alpha\t");
-				System.out.println(sid);
+				System.out.println(sid+ " " +st.getAlpha() +" "+ st.getTest_cost());
 				sid++;
+				
 			}
 			int x =0;
 		}
 	
+	
 	}
 	
+	public void select_tree () {
+		/*
+		 * The best tree is selected from this series of trees with the classification error not exceeding 
+		 * an expected error rate on some test set (cross-validation error), 
+		 * which is done at the second stage.
+		 */
+		double value_to_select = procedure.getErrorEstimate();
+	
+	}
+	
 	public void prun_tree(DecisionTree tree) {
-		TreeSequenceProc search = new TreeSequenceProc(tree, best_stats.getAlpha(), new AnAlphaProc());
+		TreeSequenceProc search = new TreeSequenceProc(tree, new AnAlphaProc(best_stats.getAlpha()));
 		search.iterate_trees(0);
 		//search.getTreeSequence()// to go back
 		
 	}
-	
-//	private void sequence_trees(DecisionTree dt_0, double init_alpha, AlphaSelectionProc proc) {
-//		if (slog.debug() !=null)
-//			slog.debug().log(dt_0.toString() +"\n");
-//		// if number of non-terminal nodes in the tree is more than 1 
-//		// = if there exists at least one non-terminal node different than root
-//		if (dt_0.getNumNonTerminalNodes() < 1) {
-//			if (slog.debug() !=null)
-//				slog.debug().log(":sequence_trees:TERMINATE-There is no non-terminal nodes? " + dt_0.getNumNonTerminalNodes() +"\n");
-//			return; 
-//		} else if (dt_0.getNumNonTerminalNodes() == 1 && dt_0.getRoot().getNumLeaves()<=1) {
-//			if (slog.debug() !=null)
-//				slog.debug().log(":sequence_trees:TERMINATE-There is only one node left which is root node " + dt_0.getNumNonTerminalNodes()+ " and it has only one leaf (pruned)" +dt_0.getRoot().getNumLeaves()+"\n");
-//			return; 
-//		}
-//		// for each non-leaf subtree
-//		ArrayList<TreeNode> candidate_nodes = new ArrayList<TreeNode>();
-//		// to find all candidates with min_alpha value
-//		TreeSequenceProc search = new TreeSequenceProc(init_alpha, proc);//100000.0d, new MinAlphaProc()); 
-//		
-//		
-//		search.find_candidate_nodes(dt_0, dt_0.getRoot(), candidate_nodes);
-//		double min_alpha = search.getTheAlpha();
-//		System.out.println("!!!!!!!!!!!alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
-//		
-//		if (candidate_nodes.size() >0) {
-//			// The one or more subtrees with that value of will be replaced by leaves
-//			// instead of getting the first node, have to process all nodes and prune all
-//			// write a method to prune all
-//			TreeNode best_node = candidate_nodes.get(0);
-//			LeafNode best_clone = new LeafNode(dt_0.getTargetDomain(), best_node.getLabel());
-//			best_clone.setRank(	best_node.getRank());
-//			best_clone.setNumMatch(best_node.getNumMatch());						//num of matching instances to the leaf node
-//			best_clone.setNumClassification(best_node.getNumLabeled());				//num of (correctly) classified instances at the leaf node
-//
-//			NodeUpdate update = new NodeUpdate(best_node, best_clone);
-//			//update.set
-//			update.setAlpha(min_alpha);
-//			update.setDecisionTree(dt_0);
-//			int k = numExtraMisClassIfPrun(best_node); // extra misclassified guys
-//			int num_leaves = best_node.getNumLeaves();
-//			int new_num_leaves = dt_0.getRoot().getNumLeaves() - num_leaves +1;
-//
-//			TreeNode father_node = best_node.getFather();
-//			if (father_node != null) {
-//				for(Object key: father_node.getChildrenKeys()) {
-//					if (father_node.getChild(key).equals(best_node)) {
-//						father_node.putNode(key, best_clone);
-//						break;
-//					}
-//				}
-//				updateLeaves(father_node, -num_leaves+1);
-//			} else {
-//				// this node does not have any father node it is the root node of the tree
-//				dt_0.setRoot(best_clone);
-//			}
-//			
-//			
-//			ArrayList<InstanceList> sets = procedure.getFold(dt_0.getId());
-//			//InstanceList learning_set = sets.get(0);
-//			InstanceList validation_set = sets.get(1);
-//			
-//			int error = 0;
-//			SingleTreeTester t= new SingleTreeTester(dt_0);
-//			for (int index_i = 0; index_i < validation_set.getSize(); index_i++) {
-//				Integer result = t.test(validation_set.getInstance(index_i));
-//				if (result == Stats.INCORRECT) {
-//					error ++;
-//				}
-//			}
-//			
-//			
-//			update.setCross_validated_cost(error);
-//			int new_resubstitution_cost = dt_0.getTrainingError() + k;
-//			double cost_complexity = new_resubstitution_cost + min_alpha * (new_num_leaves);
-//			
-//			
-//			if (slog.debug() !=null)
-//				slog.debug().log(":sequence_trees:cost_complexity of selected tree "+ cost_complexity +"\n");
-//			update.setResubstitution_cost(new_resubstitution_cost);
-//			// Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
-//			update.setCost_complexity(cost_complexity);
-//			update.setNum_terminal_nodes(new_num_leaves);
-//
-//			updates.get(dt_0.getId()).add(update);
-//			
-//			if (slog.debug() !=null)
-//				slog.debug().log(":sequence_trees:error "+ error +"<?"+ procedure.getValidationErrorEstimate() * 1.6 +"\n");
-//
-//			if (error < procedure.getValidationErrorEstimate() * 1.6) {
-//				// if the error of the tree is not that bad
-//				
-//				if (error < best_update.getCross_validated_cost()) {
-//					best_update = update;
-//					if (slog.debug() !=null)
-//						slog.debug().log(":sequence_trees:best node updated \n");
-//		
-//				}
-//					
-//				sequence_trees(dt_0);
-//			} else {
-//				update.setStopTree();
-//				return;
-//			}
-//		}
-//		
-//	}
 
 	private void updateLeaves(TreeNode my_node, int i) {
 		my_node.setNumLeaves(my_node.getNumLeaves() + i);
@@ -217,32 +133,27 @@
 		
 		private static final double MAX_ERROR_RATIO = 0.99;
 		private DecisionTree focus_tree;
-		private double the_alpha;
+		//private double the_alpha;
 		private AlphaSelectionProc alpha_proc;
 		private ArrayList<NodeUpdate> tree_sequence;
 		private ArrayList<TreeStats> tree_sequence_stats;
 		
 		private TreeStats best_tree_stats;
-		public TreeSequenceProc(DecisionTree dt, double init_alpha, AlphaSelectionProc cond) {
+		public TreeSequenceProc(DecisionTree dt, AlphaSelectionProc cond) { //, double init_alpha
 			focus_tree = dt;
-			the_alpha = init_alpha;
+			//the_alpha = init_alpha;
 			alpha_proc = cond;
 			tree_sequence = new ArrayList<NodeUpdate>();
+			tree_sequence_stats = new ArrayList<TreeStats>();
 			
 			best_tree_stats = new TreeStats(10000000.0d);
-
-//			init_tree.setResubstitution_cost(dt.getTrainingError());
-//			init_tree.setAlpha(-1);	// dont know
-//			init_tree.setCost_complexity(-1);	// dont known
-//			init_tree.setDecisionTree(dt);
-//			init_tree.setNum_terminal_nodes(dt.getRoot().getNumLeaves());
 			
 			NodeUpdate init_tree = new NodeUpdate(dt.getValidationError());
 			tree_sequence.add(init_tree);
 			
 			TreeStats init_tree_stats = new TreeStats(dt.getValidationError());
 			init_tree_stats.setResubstitution_cost(dt.getTrainingError());
-			init_tree_stats.setAlpha(-1);	// dont know
+			init_tree_stats.setAlpha(0.0d);	// dont know
 			init_tree_stats.setCost_complexity(-1);	// dont known
 //			init_tree_stats.setDecisionTree(dt);
 			init_tree_stats.setNum_terminal_nodes(dt.getRoot().getNumLeaves());		
@@ -281,7 +192,7 @@
 			find_candidate_nodes(focus_tree.getRoot(), candidate_nodes);
 			//double min_alpha = search.getTheAlpha();
 			double min_alpha = getTheAlpha();
-			System.out.println("!!!!!!!!!!!alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
+			System.out.println("!!!!!!!!!!! dt:"+focus_tree.getId()+" ite "+i+" alpha: "+min_alpha + " num_nodes_found "+candidate_nodes.size());
 			
 			if (candidate_nodes.size() >0) {
 				// The one or more subtrees with that value of will be replaced by leaves
@@ -340,7 +251,7 @@
 
 				int new_num_leaves = focus_tree.getRoot().getNumLeaves();
 				
-				double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_misclass, focus_tree.FACTS_READ);
+				double new_resubstitution_cost = focus_tree.getTrainingError() + Util.division(change_in_training_misclass, procedure.getTrainingDataSize(focus_tree.getId())/*focus_tree.FACTS_READ*/);
 				double cost_complexity = new_resubstitution_cost + min_alpha * (new_num_leaves);
 				
 				
@@ -349,7 +260,7 @@
 				
 				
 				stats.setAlpha(min_alpha);
-				stats.setCross_validated_cost(percent_error);
+				stats.setTest_cost(percent_error);
 				stats.setResubstitution_cost(new_resubstitution_cost);
 				// Cost Complexity = Resubstitution Misclassification Cost + \alpha . Number of terminal nodes
 				stats.setCost_complexity(cost_complexity);
@@ -357,23 +268,27 @@
 				tree_sequence_stats.add(stats);
 				
 				if (slog.debug() !=null)
-					slog.debug().log(":sequence_trees:error "+ percent_error +"<?"+ procedure.getValidationErrorEstimate() * 1.6 +"\n");
+					slog.debug().log(":sequence_trees:error "+ percent_error +"<?"+ procedure.getErrorEstimate() * 1.6 +"\n");
 
 				if (percent_error < MAX_ERROR_RATIO) { //procedure.getValidationErrorEstimate() * 1.6) {
 					// if the error of the tree is not that bad
 					
-					if (percent_error < best_tree_stats.getCross_validated_cost()) {
+					if (percent_error < best_tree_stats.getTest_cost()) {
 						best_tree_stats = stats;
 						if (slog.debug() !=null)
 							slog.debug().log(":sequence_trees:best node updated \n");
 			
 					}
-						
+					//	TODO update alpha_proc by increasing the min_alpha				the_alpha += 10;	
+					alpha_proc.init_proc(alpha_proc.getAlpha() + INIT_ALPHA);
 					iterate_trees(i+1);
 				} else {
 					//TODO update.setStopTree();
 					return;
 				}
+			} else {
+				if (slog.debug() !=null)
+					slog.debug().log(":sequence_trees:no candidate node is found ???? \n");
 			}
 			
 		}
@@ -391,11 +306,17 @@
 				int k = numExtraMisClassIfPrun(my_node);
 				int num_leaves = my_node.getNumLeaves();
 				
-				double alpha = ((double)k)/((double)focus_tree.FACTS_READ * (num_leaves-1));
+				if (k==0) {
+					if (slog.debug() !=null)
+						slog.debug().log(":search_alphas:k == 0\n" );
+					
+				}
+				double alpha = ((double)k)/((double)procedure.getTrainingDataSize(focus_tree.getId())/*focus_tree.FACTS_READ*/ * (num_leaves-1));
 				if (slog.debug() !=null)
-					slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+the_alpha+ " k "+k+" num_leaves "+num_leaves+" all "+ focus_tree.FACTS_READ +  "\n");
+					slog.debug().log(":search_alphas:alpha "+ alpha+ "/"+alpha_proc.getAlpha()+ " k "+k+" num_leaves "+num_leaves+" all "+ procedure.getTrainingDataSize(focus_tree.getId()) +  "\n");
 				
-				the_alpha = alpha_proc.update_nodes(alpha, the_alpha, my_node, nodes);		
+				//the_alpha = alpha_proc.check_node(alpha, the_alpha, my_node, nodes);
+				alpha_proc.check_node(alpha, my_node, nodes);
 				
 				for (Object attributeValue : my_node.getChildrenKeys()) {
 					TreeNode child = my_node.getChild(attributeValue);
@@ -407,7 +328,7 @@
 		}
 		
 		public double getTheAlpha() {
-			return the_alpha;
+			return alpha_proc.getAlpha();
 		}
 	}
 	
@@ -422,17 +343,22 @@
 	}
 	
 	public interface AlphaSelectionProc {
-		public double update_nodes(double cur_alpha, double the_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes);
+		public double check_node(double cur_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes);
+		public void init_proc(double value);
+		public double getAlpha();
 	}
 	
 	public class AnAlphaProc implements AlphaSelectionProc{
-//		private double an_alpha;
-//		public AnAlphaProc(double value) {
-//			an_alpha = value;
-//		}
+		private double an_alpha;
+		public AnAlphaProc(double value) {
+			an_alpha = value;
+		}
+		public void init_proc(double value) {
+			// TODO ????
+		}
 		
-		public double update_nodes(double cur_alpha, double the_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes) {
-			if (cur_alpha == the_alpha) {
+		public double check_node(double cur_alpha,TreeNode cur_node, ArrayList<TreeNode> nodes) {
+			if (Util.epsilon(cur_alpha - an_alpha)) {
 				for(TreeNode parent:nodes) {
 					if (isChildOf(cur_node, parent))
 						return cur_alpha;// it is not added
@@ -443,44 +369,64 @@
 			return cur_alpha;
 		}
 
-//		public double getAlpha() {
-//			return an_alpha;
-//		}
+		public double getAlpha() {
+			return an_alpha;
+		}
 	}
 	
 	public class MinAlphaProc implements AlphaSelectionProc{
-//		private double best_alpha;
-//		public MinAlphaProc(double value) {
-//			best_alpha = value;
-//		}
+		private double sum_min_alpha, init_min;
+		private int num_minimum;
+		public MinAlphaProc(double value) {
+			init_min = value;
+			sum_min_alpha = 0;
+			num_minimum = 0;
+		}
 		
-		public double update_nodes(double cur_alpha, double the_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes) {
-			if (cur_alpha == the_alpha) {
+		public void init_proc(double value) {
+			init_min = value;
+			sum_min_alpha = 0;
+			num_minimum = 0;
+		}
+		
+		public double check_node(double cur_alpha, TreeNode cur_node, ArrayList<TreeNode> nodes) {
+			double average_of_min_alphas = getAlpha();
+			if (slog.debug() !=null)
+				slog.debug().log(":search_alphas:alpha "+ cur_alpha+ "/"+average_of_min_alphas+ " diff "+(cur_alpha - average_of_min_alphas)+"\n");
+			
+			if (Util.epsilon(cur_alpha - average_of_min_alphas)) {
+				// check if the cur_node is a child of any node that has been added to the list before.
 				for(TreeNode parent:nodes) {
 					if (isChildOf(cur_node, parent))
-						return cur_alpha;// it is not added
+						return cur_alpha;// if it is the case do not add the node
 				}
-				// add this one to the set
+				// else add this one to the set
 				nodes.add(cur_node);
+				sum_min_alpha += cur_alpha;
+				num_minimum ++;
 				return cur_alpha;
-			} else if (cur_alpha < the_alpha) {
+			} else if (cur_alpha < average_of_min_alphas) {
 				
 				nodes.clear(); // can not put a new 'cause then it does not update the global one = new ArrayList<TreeNode>();
 				// remove the ones you found and replace with that one
 				//tree_sequence.get(dt_id).put(my_node), alpha
-				
+				num_minimum = 1;
+				sum_min_alpha = cur_alpha;
 				nodes.add(cur_node);
 				return cur_alpha;
 				
 			} else {
 				
 			}
-			return the_alpha;
+			return sum_min_alpha/num_minimum;
 		}
 
-//		public double getAlpha() {
-//			return best_alpha;
-//		}
+		public double getAlpha() {
+			if (num_minimum == 0)
+				return init_min;
+			else
+				return sum_min_alpha/num_minimum;
+		}
 	}
 	
 	public class NodeUpdate{
@@ -531,7 +477,7 @@
 
 		private int iteration_id;
 		private int num_terminal_nodes;
-		private double cross_validated_cost;
+		private double test_cost;
 		private double resubstitution_cost;
 		private double cost_complexity;
 		private double alpha;
@@ -542,7 +488,7 @@
 		// to set an node update with the worst cross validated error
 		public TreeStats(double error) {
 			iteration_id = 0;
-			cross_validated_cost = error;
+			test_cost = error;
 		}
 		
 		public void iteration_id(int i) {
@@ -556,12 +502,12 @@
 			this.num_terminal_nodes = num_terminal_nodes;
 		}
 
-		public double getCross_validated_cost() {
-			return cross_validated_cost;
+		public double getTest_cost() {
+			return test_cost;
 		}
 
-		public void setCross_validated_cost(double cross_validated_cost) {
-			this.cross_validated_cost = cross_validated_cost;
+		public void setTest_cost(double valid_cost) {
+			this.test_cost = valid_cost;
 		}
 
 		public double getResubstitution_cost() {
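
For readers following the alpha bookkeeping above: per search_alphas, each candidate internal node t is assigned

    alpha(t) = k(t) / (N * (num_leaves(t) - 1))

where k(t) is the number of extra training misclassifications if the subtree under t is collapsed to a leaf (numExtraMisClassIfPrun), N is now the size of that tree's training fold (procedure.getTrainingDataSize(id) instead of FACTS_READ), and num_leaves(t) is the number of terminal nodes under t. Each tree in the pruning sequence is then scored as

    cost_complexity = resubstitution_cost + alpha * (number of terminal nodes)

MinAlphaProc keeps a running average of the minimal alphas (sum_min_alpha / num_minimum) instead of the old single the_alpha field, and AnAlphaProc matches a fixed alpha within the Util.epsilon tolerance added in this revision.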

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/TreeNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/TreeNode.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/TreeNode.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -27,6 +27,8 @@
 	private int label_size;
 	private int leaves;
 	
+	private int depth;
+	
 	public TreeNode(Domain domain) {
 		this.father = null;
 		this.domain = domain;
@@ -34,6 +36,14 @@
 		
 	}
 	
+	public void setDepth(int d) {
+		depth = d;
+	}
+	
+	public int getDepth() {
+		return depth;
+	}
+	
 	public double getRank() {
 		return rank;
 	}

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -52,8 +52,11 @@
 		}
 	
 		int N = class_instances.getSize();
-		int NUM_DATA = (int)(TREE_SIZE_RATIO * N);
-		_trainer.setDataSizePerTree(NUM_DATA);
+		int NUM_DATA = (int)(TREE_SIZE_RATIO * N);	// TREE_SIZE_RATIO = 1.0, all data is used to train the trees again again
+		_trainer.setTrainingDataSizePerTree(NUM_DATA);
+		
+		/* all data fed to each tree, the same data?? */
+		_trainer.setTrainingDataSize(NUM_DATA); // TODO????
 
 		
 		forest = new ArrayList<DecisionTree> (FOREST_SIZE);

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/AdaBoostKBuilder.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -48,9 +48,13 @@
 		}
 	
 		int N = class_instances.getSize();
+		//_trainer.setTrainingDataSize(N); not only N data is fed. 
+		
 		int K = _trainer.getTargetDomain().getCategoryCount();
 		int M = (int)(TREE_SIZE_RATIO * N);
-		_trainer.setDataSizePerTree(M);
+		_trainer.setTrainingDataSizePerTree(M);
+		/* M data fed to each tree, there are FOREST_SIZE trees*/
+		_trainer.setTrainingDataSize(M * FOREST_SIZE); 
 
 		
 		forest = new ArrayList<DecisionTree> (FOREST_SIZE);
@@ -61,7 +65,6 @@
 			for (int index_j=0; index_j<K; index_j++) {
 				Instance inst_i = class_instances.getInstance(index_i);
 				
-				
 				Object instance_target = inst_i.getAttrValue(_trainer.getTargetDomain().getFReferenceName());
 				Object instance_target_category = _trainer.getTargetDomain().getCategoryOf(instance_target);
 				Object target_category= _trainer.getTargetDomain().getCategory(index_j);

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -1,5 +1,6 @@
 package org.drools.learner.builder;
 
+import java.util.ArrayList;
 import java.util.Hashtable;
 import java.util.List;
 
@@ -11,21 +12,31 @@
 import org.drools.learner.eval.Heuristic;
 import org.drools.learner.eval.InformationContainer;
 import org.drools.learner.eval.InstDistribution;
+import org.drools.learner.eval.StoppingCriterion;
 import org.drools.learner.tools.FeatureNotSupported;
 import org.drools.learner.tools.Util;
 
 public class C45Learner extends Learner{
 	
 	private AttributeChooser chooser;
+	private ArrayList<StoppingCriterion> criteria; 
 	
 	public C45Learner(Heuristic hf) {
 		super();
 		super.setDomainAlgo(DomainAlgo.QUANTITATIVE);
 		chooser = new AttributeChooser(hf);
+		criteria = null;
 	}
 	
 	
-	protected TreeNode train(DecisionTree dt, InstDistribution data_stats) {//List<Instance> data) {
+	public C45Learner(Heuristic hf, ArrayList<StoppingCriterion> _criteria) {
+		super();
+		super.setDomainAlgo(DomainAlgo.QUANTITATIVE);
+		chooser = new AttributeChooser(hf);
+		criteria = _criteria;
+	}
+	
+	protected TreeNode train(DecisionTree dt, InstDistribution data_stats,  int depth) {//List<Instance> data) {
 		
 		if (data_stats.getSum() == 0) {
 			throw new RuntimeException("Nothing to classify, factlist is empty");
@@ -42,7 +53,7 @@
 			LeafNode classifiedNode = new LeafNode(dt.getTargetDomain()				/* target domain*/, 
 												   data_stats.get_winner_class() 	/*winner target category*/);
 			classifiedNode.setRank(	(double)data_stats.getSum()/
-									(double)this.getDataSize()/* total size of data fed to dt*/);
+									(double)this.getTrainingDataSize()/* total size of data fed to dt*/);
 			classifiedNode.setNumMatch(data_stats.getSum());						//num of matching instances to the leaf node
 			classifiedNode.setNumClassification(data_stats.getSum());				//num of classified instances at the leaf node
 			
@@ -58,7 +69,7 @@
 			LeafNode noAttributeLeftNode = new LeafNode(dt.getTargetDomain()			/* target domain*/, 
 														winner);
 			noAttributeLeftNode.setRank((double)data_stats.getVoteFor(winner)/
-										(double)this.getDataSize()						/* total size of data fed to dt*/);
+										(double)this.getTrainingDataSize()						/* total size of data fed to dt*/);
 			noAttributeLeftNode.setNumMatch(data_stats.getSum());						//num of matching instances to the leaf node
 			noAttributeLeftNode.setNumClassification(data_stats.getVoteFor(winner));	//num of classified instances at the leaf node
 			//noAttributeLeftNode.setInfoMea(best_attr_eval.attribute_eval);
@@ -66,24 +77,42 @@
 			
 			/* we need to know how many guys cannot be classified and who these guys are */
 			data_stats.missClassifiedInstances(missclassified_data);
-			dt.setTrainingError(dt.getTrainingError() + data_stats.getSum()/dt.FACTS_READ);
+			dt.setTrainingError(dt.getTrainingError() + data_stats.getSum()/getTrainingDataSize());
 			return noAttributeLeftNode;
 		}
-		
 	
 		InformationContainer best_attr_eval = new InformationContainer();
-		
+		best_attr_eval.setStats(data_stats);
+		best_attr_eval.setDepth(depth);
+		best_attr_eval.setTotalNumData(getTrainingDataSizePerTree());
+
 		/* choosing the best attribute in order to branch at the current node*/
 		chooser.chooseAttribute(best_attr_eval, data_stats, attribute_domains);
+		
+		if (criteria != null && criteria.size()>0) {
+			for (StoppingCriterion sc: criteria) 
+				if (sc.stop(best_attr_eval)) {
+					Object winner = data_stats.get_winner_class();
+					LeafNode majorityNode = new LeafNode(dt.getTargetDomain(), winner);
+					majorityNode.setRank((double)data_stats.getVoteFor(winner)/
+										 (double)this.getTrainingDataSize()						/* total size of data fed to trainer*/);
+					majorityNode.setNumMatch(data_stats.getSum());
+					majorityNode.setNumClassification(data_stats.getVoteFor(winner));
+					
+					/* we need to know how many guys cannot be classified and who these guys are */
+					data_stats.missClassifiedInstances(missclassified_data);
+					dt.setTrainingError(dt.getTrainingError() + (data_stats.getSum()-data_stats.getVoteFor(winner))/getTrainingDataSize());
+					return majorityNode;
+				}
+		}
 		Domain node_domain = best_attr_eval.domain;
-		
 		if (slog.debug() != null)
 			slog.debug().log("\n"+Util.ntimes("*", 20)+" 1st best attr: "+ node_domain);
 
 		TreeNode currentNode = new TreeNode(node_domain);
 		currentNode.setNumMatch(data_stats.getSum());									//num of matching instances to the leaf node
 		currentNode.setRank((double)data_stats.getSum()/
-							(double)this.getDataSize()									/* total size of data fed to dt*/);
+							(double)this.getTrainingDataSize()									/* total size of data fed to trainer*/);
 		currentNode.setInfoMea(best_attr_eval.attribute_eval);
 		//what the highest represented class is and what proportion of items at that node actually are that class
 		currentNode.setLabel(data_stats.get_winner_class());
@@ -107,6 +136,7 @@
 			
 			/* list of domains except the choosen one (&target domain)*/
 			DecisionTree child_dt = new DecisionTree(dt, node_domain);	
+			child_dt.FACTS_READ = dt.FACTS_READ;
 			
 			if (filtered_stats == null || filtered_stats.get(category) == null || filtered_stats.get(category).getSum() ==0) {
 				/* majority !!!! */
@@ -120,7 +150,7 @@
 				majorityNode.setFather(currentNode);
 				currentNode.putNode(category, majorityNode);
 			} else {
-				TreeNode newNode = train(child_dt, filtered_stats.get(category));//, attributeNames_copy
+				TreeNode newNode = train(child_dt, filtered_stats.get(category), depth+1);//, attributeNames_copy
 				newNode.setFather(currentNode);
 				currentNode.putNode(category, newNode);
 			}

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeFactory.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -1,5 +1,7 @@
 package org.drools.learner.builder;
 
+import java.util.ArrayList;
+
 import org.drools.WorkingMemory;
 import org.drools.learner.DecisionTree;
 import org.drools.learner.DecisionTreePruner;
@@ -9,8 +11,10 @@
 import org.drools.learner.builder.DecisionTreeBuilder.TreeAlgo;
 import org.drools.learner.eval.CrossValidation;
 import org.drools.learner.eval.Entropy;
+import org.drools.learner.eval.EstimatedNodeSize;
 import org.drools.learner.eval.GainRatio;
 import org.drools.learner.eval.Heuristic;
+import org.drools.learner.eval.StoppingCriterion;
 import org.drools.learner.tools.FeatureNotSupported;
 import org.drools.learner.tools.Util;
 
@@ -201,6 +205,92 @@
 		return learner.getTree();
 	}
 	
+	
+	public static DecisionTree createSingleC45E_StoppingCriteria(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+		return createSingleC45_Stop(wm, obj_class, new Entropy());
+	}
+	
+//	public static DecisionTree createSingleC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+//		return createSingleC45(wm, obj_class, new GainRatio());
+//	}
+	
+	protected static DecisionTree createSingleC45_Stop(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
+		DataType data = Learner.DEFAULT_DATA;
+		ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+		stopping_criteria.add(new EstimatedNodeSize(0.5));
+		C45Learner learner = new C45Learner(h, stopping_criteria);
+		SingleTreeBuilder single_builder = new SingleTreeBuilder();
+		
+//		String algo_suffices = org.drools.learner.deprecated.DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
+//		String executionSignature = org.drools.learner.deprecated.DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
+		String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
+		String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
+		
+		/* create the memory */
+		Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
+		single_builder.build(mem, learner);//obj_class, target_attr, working_attr
+		
+		SingleTreeTester tester = new SingleTreeTester(learner.getTree());
+		tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature);
+		//Tester.test(c45, mem.getClassInstances());
+		
+		learner.getTree().setSignature(executionSignature);
+		return learner.getTree();
+	}
+	
+	
+	public static DecisionTree createSinglePrunnedStopC45E(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+		return createSinglePrunnedStopC45(wm, obj_class, new Entropy());
+	}
+	public static DecisionTree createSinglePrunnedStopC45G(WorkingMemory wm, Class<? extends Object> obj_class) throws FeatureNotSupported {
+		return createSinglePrunnedStopC45(wm, obj_class, new GainRatio());
+	}
+	
+	protected static DecisionTree createSinglePrunnedStopC45(WorkingMemory wm, Class<? extends Object> obj_class, Heuristic h) throws FeatureNotSupported {
+		DataType data = Learner.DEFAULT_DATA;
+		ArrayList<StoppingCriterion> stopping_criteria = new ArrayList<StoppingCriterion>();
+		stopping_criteria.add(new EstimatedNodeSize(0.05));
+		C45Learner learner = new C45Learner(h, stopping_criteria);
+		
+		SingleTreeBuilder single_builder = new SingleTreeBuilder();
+	
+		String algo_suffices = DecisionTreeFactory.getAlgoSuffices(learner.getDomainAlgo(), single_builder.getTreeAlgo());
+		String executionSignature = DecisionTreeFactory.getSignature(obj_class, "", algo_suffices);
+		
+		/* create the memory */
+		Memory mem = Memory.createFromWorkingMemory(wm, obj_class, learner.getDomainAlgo(), data);
+		single_builder.build(mem, learner);//obj_class, target_attr, working_attr
+		
+		CrossValidation validater = new CrossValidation(10, mem.getClassInstances());
+		validater.validate(learner);
+		
+		DecisionTreePruner pruner = new DecisionTreePruner(validater);
+		pruner.prun_to_estimate();
+		
+		// you should be able to get the pruned tree
+		// prun.getMinimumCostTree()
+		// prun.getOptimumCostTree()
+		
+		// test the tree
+		SingleTreeTester tester = new SingleTreeTester(learner.getTree());
+		tester.printStats(tester.test(mem.getClassInstances()), Util.DRL_DIRECTORY + executionSignature);
+		
+	/* Once Talpha is found, the tree finally suggested for use is the one
+	 * that minimises the cost-complexity over all the data; the pruner is then used to prune that tree
+		 */
+		pruner.prun_tree(learner.getTree());
+		
+		
+		// test the tree again
+		
+		
+		//Tester.test(c45, mem.getClassInstances());
+		
+		learner.getTree().setSignature(executionSignature);
+		return learner.getTree();
+	}
+	
+	
 	public static String getSignature(Class<? extends Object> obj_class, String fileName, String suffices) {
 		
 		//String fileName = (dataFile == null || dataFile == "") ? this.getRuleClass().getSimpleName().toLowerCase(): dataFile; 			

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -43,8 +43,12 @@
 		}
 	
 		int N = class_instances.getSize();
+		// _trainer.setTrainingDataSize(N); => wrong
 		int tree_capacity = (int)(TREE_SIZE_RATIO * N);
-		_trainer.setDataSizePerTree(tree_capacity);
+		_trainer.setTrainingDataSizePerTree(tree_capacity);
+		
+		/* tree_capacity number of data fed to each tree, there are FOREST_SIZE trees*/
+		_trainer.setTrainingDataSize(tree_capacity * FOREST_SIZE); 
 
 		
 		forest = new ArrayList<DecisionTree> (FOREST_SIZE);

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -22,7 +22,7 @@
 		chooser = new AttributeChooser(hf);
 	}
 	
-	protected TreeNode train(DecisionTree dt, InstDistribution data_stats) {//List<Instance> data) {
+	protected TreeNode train(DecisionTree dt, InstDistribution data_stats, int depth) {//List<Instance> data) {
 		
 		if (data_stats.getSum() == 0) {
 			throw new RuntimeException("Nothing to classify, factlist is empty");
@@ -39,7 +39,7 @@
 			LeafNode classifiedNode = new LeafNode(dt.getTargetDomain()				/* target domain*/, 
 												   data_stats.get_winner_class() 	/*winner target category*/);
 			classifiedNode.setRank(	(double)data_stats.getSum()/
-									(double)this.getDataSize()/* total size of data fed to dt*/);
+									(double)this.getTrainingDataSize()/* total size of data fed to dt*/);
 			classifiedNode.setNumMatch(data_stats.getSum());						//num of matching instances to the leaf node
 			classifiedNode.setNumClassification(data_stats.getSum());				//num of classified instances at the leaf node
 			return classifiedNode;
@@ -53,7 +53,7 @@
 			LeafNode noAttributeLeftNode = new LeafNode(dt.getTargetDomain()			/* target domain*/, 
 														winner);
 			noAttributeLeftNode.setRank((double)data_stats.getVoteFor(winner)/
-										(double)this.getDataSize()						/* total size of data fed to dt*/);
+										(double)this.getTrainingDataSize()						/* total size of data fed to dt*/);
 			noAttributeLeftNode.setNumMatch(data_stats.getSum());						//num of matching instances to the leaf node
 			noAttributeLeftNode.setNumClassification(data_stats.getVoteFor(winner));	//num of classified instances at the leaf node
 			return noAttributeLeftNode;
@@ -89,7 +89,7 @@
 				majorityNode.setNumClassification(0);
 				currentNode.putNode(category, majorityNode);
 			} else {
-				TreeNode newNode = train(child_dt, filtered_stats.get(category));//, attributeNames_copy
+				TreeNode newNode = train(child_dt, filtered_stats.get(category), depth+1);//, attributeNames_copy
 				currentNode.putNode(category, newNode);
 			}
 		}

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/Learner.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -22,7 +22,7 @@
 	
 	public static enum DataType {PRIMITIVE, STRUCTURED, COLLECTION}
 	public static DataType DEFAULT_DATA = DataType.PRIMITIVE;
-	private int data_size;
+	private int data_size, data_size_per_tree;
 	private DecisionTree best_tree;
 	private InstanceList input_data;
 	protected HashSet<Instance> missclassified_data;
@@ -31,10 +31,11 @@
 	private DomainAlgo algorithm;
 	
 	
-	protected abstract TreeNode train(DecisionTree dt, InstDistribution data_stats);
+	protected abstract TreeNode train(DecisionTree dt, InstDistribution data_stats, int depth);
 	
 	public Learner() {
 		this.data_size = 0;
+		this.data_size_per_tree = 0;
 	}
 
 	
@@ -47,9 +48,11 @@
 		
 		InstDistribution stats_by_class = new InstDistribution(dt.getTargetDomain());
 		stats_by_class.calculateDistribution(working_instances.getInstances());
+		
+		
 		dt.FACTS_READ += working_instances.getSize();
 
-		TreeNode root = train(dt, stats_by_class);
+		TreeNode root = train(dt, stats_by_class, 0);
 		dt.setRoot(root);
 		//flog.debug("Result tree\n" + dt);
 		return dt;
@@ -76,7 +79,7 @@
 			stats_by_class.calculateDistribution(working_instances.getInstances());
 			dt.FACTS_READ += working_instances.getSize();
 			
-			TreeNode root = train(dt, stats_by_class);
+			TreeNode root = train(dt, stats_by_class, 0);
 			dt.setRoot(root);
 			//flog.debug("Result tree\n" + dt);
 		}
@@ -84,13 +87,20 @@
 	}
 	
 	
-	public void setDataSizePerTree(int num) {
-		this.data_size = num;
+	public void setTrainingDataSizePerTree(int num) {
+		this.data_size_per_tree = num;
 		
 		missclassified_data = new HashSet<Instance>();
 	}
 	
-	public int getDataSize() {
+	public int getTrainingDataSizePerTree() {
+		return this.data_size_per_tree;
+	}
+	
+	public void setTrainingDataSize(int num) {
+		this.data_size = num;
+	}
+	public int getTrainingDataSize() {
 		return this.data_size;
 	}
 	

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -39,8 +39,8 @@
 			System.exit(0);
 			// TODO put the feature not supported exception || implement it
 		}
-		
-		_trainer.setDataSizePerTree(class_instances.getSize());
+		_trainer.setTrainingDataSize(class_instances.getSize());
+		_trainer.setTrainingDataSizePerTree(class_instances.getSize());
 		one_tree = _trainer.train_tree(class_instances);
 		_trainer.setBestTree(one_tree);
 	}
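
The renamed setters now carry two distinct quantities: setTrainingDataSizePerTree() is the number of instances each individual tree sees (passed, via InformationContainer.setTotalNumData, to EstimatedNodeSize), while setTrainingDataSize() is the total fed to the trainer and is what the learners divide by when setting leaf ranks and the training error. For a single tree the two coincide, as above; for the ensemble builders they differ. A made-up illustration with N = 1000 instances, TREE_SIZE_RATIO = 0.9 and FOREST_SIZE = 10 in ForestBuilder:

    _trainer.setTrainingDataSizePerTree(900);   // tree_capacity, data seen by each tree
    _trainer.setTrainingDataSize(900 * 10);     // tree_capacity * FOREST_SIZE, data across the forest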

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/CrossValidation.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -106,10 +106,11 @@
 			dt.calc_num_node_leaves(dt.getRoot());
 			
 			if (slog.error() !=null)
-				slog.error().log("The estimate of : "+(i-1)+" training=" +dt.getTrainingError() +" valid=" + error +" num_leaves=" + dt.getRoot().getNumLeaves()+"\n");
+				slog.error().log("The estimate of : "+(i-1)+" training=" +dt.getTrainingError() +" valid=" + dt.getValidationError() +" num_leaves=" + dt.getRoot().getNumLeaves()+"\n");
 		
+			/* moving averages */
 			validation_error_estimate += ((double)error/(double) fold_size)/(double)k_fold;
-			training_error_estimate += ((double)dt.getTrainingError()/(double)(num_instances-fold_size))/(double)k_fold;
+			training_error_estimate += ((double)dt.getTrainingError())/(double)k_fold;//((double)dt.getTrainingError()/(double)(num_instances-fold_size))/(double)k_fold;
 			num_leaves_estimate += (double)dt.getRoot().getNumLeaves()/(double)k_fold;
 
 
@@ -190,6 +191,10 @@
 //		
 //	}
 	
+	public int getTrainingDataSize(int i) {
+		return num_instances-getFoldSize(i);
+	}
+	
 	public double getAlphaEstimate() {
 		return alpha_estimate;
 	}
@@ -197,7 +202,7 @@
 		int excess = num_instances % k_fold;
 		return (int) num_instances/k_fold + (i < excess? 1:0);
 	}
-	public double getValidationErrorEstimate() {
+	public double getErrorEstimate() {
 		return validation_error_estimate;
 	}
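
The new getTrainingDataSize(i) simply complements getFoldSize(i): with num_instances = 103 and k_fold = 10, excess = 3, so folds 0-2 hold 11 instances and folds 3-9 hold 10; the corresponding training sets hold 103 - 11 = 92 and 103 - 10 = 93 instances, which is the denominator the pruner now uses in place of FACTS_READ.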
 	

Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/EstimatedNodeSize.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/EstimatedNodeSize.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/EstimatedNodeSize.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -0,0 +1,29 @@
+package org.drools.learner.eval;
+
+public class EstimatedNodeSize implements StoppingCriterion {
+	
+	private double outlier_percentage;
+	private int estimated_sum_branch;
+	private int num_times_branched;
+	
+	
+	public EstimatedNodeSize(double o_p) {
+		outlier_percentage = o_p;
+		num_times_branched = 0;
+	}
+	
+
+	public boolean stop(InformationContainer best_attr_eval) {
+		int d = best_attr_eval.getDepth();
+		estimated_sum_branch += best_attr_eval.domain.getCategoryCount();
+		num_times_branched ++;
+		double estimated_branch = (double)estimated_sum_branch/(double)num_times_branched;
+		// N/(b^d)
+		double estimated_size = best_attr_eval.getTotalNumData()/Math.pow(estimated_branch, d);
+		System.out.println("EstimatedNodeSize:stop: " +best_attr_eval.getNumData() + " <= " + ( Math.ceil(estimated_size*outlier_percentage)-1) +" / "+estimated_size);
+		if (best_attr_eval.getNumData() <= Math.ceil(estimated_size*outlier_percentage)-1)
+			return true;
+		else 
+			return false;
+	}
+}


Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/EstimatedNodeSize.java
___________________________________________________________________
Name: svn:eol-style
   + native
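
The EstimatedNodeSize criterion works from a rough expected node size of N / b^d, where N is the per-tree data size, d the current depth and b the running average branching factor of the attributes chosen so far; a node becomes a leaf once it holds at most ceil(outlier_percentage * N / b^d) - 1 instances. As a made-up example: with N = 1000, an average branching factor of 2 and depth 5, the estimated node size is 1000 / 32 = 31.25, so with outlier_percentage = 0.5 any node holding 15 or fewer instances stops growing.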

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/Estimator.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -10,7 +10,8 @@
 	public int getEstimatorSize();
 	public ArrayList<DecisionTree> getEstimators();
 	public ArrayList<InstanceList> getFold(int id);
+	public int getTrainingDataSize(int i);
 	
-	public double getValidationErrorEstimate();
+	public double getErrorEstimate();
 	public double getAlphaEstimate();
 }

Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ImpurityDecrease.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ImpurityDecrease.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ImpurityDecrease.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -0,0 +1,18 @@
+package org.drools.learner.eval;
+
+
+public class ImpurityDecrease implements StoppingCriterion {
+
+	private double beta = 0.1;
+	
+	public ImpurityDecrease(double _beta) {
+		beta = _beta;
+	}
+	public boolean stop(InformationContainer best_attr_eval) {
+		if (best_attr_eval.attribute_eval < beta)
+			return true;
+		else
+			return false;
+	}
+
+}


Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/ImpurityDecrease.java
___________________________________________________________________
Name: svn:eol-style
   + native

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -12,13 +12,44 @@
 	//public double gain_ratio;
 	public ArrayList<Instance> sorted_data;
 	
+	private InstDistribution stats;
+	private int depth;
+	private int total_num_data;
+	
 	public InformationContainer() {
 		domain = null;
 		attribute_eval = 0.0;
 		sorted_data = null;
 
+		depth = 0;
+		stats = null;
+		total_num_data = 0;
 	}
+
+	public void setStats(InstDistribution data_stats) {
+		stats = data_stats;
+	}
+
+	public void setDepth(int _depth) {
+		depth = _depth;
+	}
+
+	public int getDepth() {
+		return depth;
+	}
+	public double getNumData() {
+		return stats.getSum();
+	}
+
+	// total num of data fed to each tree
+	public void setTotalNumData(int num) {
+		total_num_data = num;
+	}
 	
+	public int getTotalNumData() {
+		return total_num_data ;
+	}
+	
 //	public InformationContainer(Domain _domain, double _attribute_eval, double _gain_ratio) {
 //		this.domain = _domain;
 //		this.attribute_eval = _attribute_eval;

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -99,8 +99,7 @@
 	}
 
 	
-	public Hashtable<Object, InstDistribution> splitFromCategorical(
-			Domain splitDomain, Hashtable<Object, InstDistribution> instLists) {
+	public Hashtable<Object, InstDistribution> splitFromCategorical(Domain splitDomain, Hashtable<Object, InstDistribution> instLists) {
 		if (instLists == null)
 			instLists = this.instantiateLists(splitDomain);
 		
@@ -122,8 +121,7 @@
 		return instLists;
 	}
 	
-	private void splitFromQuantitative(ArrayList<Instance> data, 
-				QuantitativeDomain attributeDomain, Hashtable<Object, InstDistribution> instLists) {
+	private void splitFromQuantitative(ArrayList<Instance> data, QuantitativeDomain attributeDomain, Hashtable<Object, InstDistribution> instLists) {
 		
 		String attributeName = attributeDomain.getFName();
 		String targetName = super.getClassDomain().getFReferenceName();

Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/MaximumDepth.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/MaximumDepth.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/MaximumDepth.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -0,0 +1,17 @@
+package org.drools.learner.eval;
+
+
+public class MaximumDepth implements StoppingCriterion {
+
+	private int limit_depth;
+	public MaximumDepth(int _depth) {
+		limit_depth = _depth;
+	}
+	public boolean stop(InformationContainer best_attr_eval) {
+		if (best_attr_eval.getDepth() <= limit_depth)
+			return false;
+		else 
+			return true;
+	}
+
+}


Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/MaximumDepth.java
___________________________________________________________________
Name: svn:eol-style
   + native

Added: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/StoppingCriterion.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/StoppingCriterion.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/StoppingCriterion.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -0,0 +1,9 @@
+package org.drools.learner.eval;
+
+import org.drools.learner.DecisionTree;
+
+public interface StoppingCriterion {
+	
+	public boolean stop(InformationContainer best_attr_eval);
+
+}


Property changes on: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/eval/StoppingCriterion.java
___________________________________________________________________
Name: svn:eol-style
   + native

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-core/src/main/java/org/drools/learner/tools/Util.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -66,6 +66,10 @@
 		return (double)x/(double)y;
 	}
 	
+	public static boolean epsilon(double d) {
+		return Math.abs(d) <= 0.0001;
+	}
+	
 	/* TODO make this all_fields arraylist as hashmap */
 	public static void getSuperFields(Class<?> clazz, ArrayList<Field> all_fields) {
 		if (clazz == Object.class)

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -37,7 +37,7 @@
 		}
 
 		// instantiate a learner for a specific object class and pass session to train
-		DecisionTree decision_tree; int ALGO = 400;
+		DecisionTree decision_tree; int ALGO = 600;
 		/* 
 		 * Single	1xx, Bag 	2xx, Boost 3xx
 		 * ID3 		x1x, C45 	x2x
@@ -74,6 +74,12 @@
 		case 400: 
 			decision_tree = DecisionTreeFactory.createSinglePrunnedC45E(session, obj_class);
 			break;
+		case 500:
+			decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
+			break;
+		case 600:
+			decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+			break;
 		default:
 			decision_tree  = DecisionTreeFactory.createSingleID3E(session, obj_class);
 		

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -40,7 +40,7 @@
 			session.insert(r);
 		}
 
-		DecisionTree decision_tree; int ALGO = 400;
+		DecisionTree decision_tree; int ALGO = 600;
 		/* 
 		 * Single	1xx, Bag 	2xx, Boost 3xx
 		 * ID3 		x1x, C45 	x2x
@@ -77,6 +77,12 @@
 		case 400: 
 			decision_tree = DecisionTreeFactory.createSinglePrunnedC45E(session, obj_class);
 			break;
+		case 500:
+			decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
+			break;
+		case 600:
+			decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+			break;
 		default:
 			decision_tree  = DecisionTreeFactory.createSingleID3E(session, obj_class);
 		

Modified: labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java	2008-08-01 05:18:28 UTC (rev 21320)
+++ labs/jbossrules/contrib/machinelearning/5.0/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java	2008-08-01 13:12:34 UTC (rev 21321)
@@ -36,7 +36,7 @@
 		}
 
 		// instantiate a learner for a specific object class and pass session to train
-		DecisionTree decision_tree; int ALGO = 221;
+		DecisionTree decision_tree; int ALGO = 600;
 		/* 
 		 * Single	1xx, Bag 	2xx, Boost 3xx
 		 * ID3 		x1x, C45 	x2x
@@ -67,6 +67,12 @@
 		case 322:
 			decision_tree  = DecisionTreeFactory.createBoostedC45G(session, obj_class);
 			break;
+		case 500:
+			decision_tree = DecisionTreeFactory.createSingleC45E_StoppingCriteria(session, obj_class);
+			break;
+		case 600:
+			decision_tree = DecisionTreeFactory.createSinglePrunnedStopC45E(session, obj_class);
+			break;
 //			case 3:
 //			decision_tree  = DecisionTreeFactory.createGlobal2(session, obj_class);
 //			break;



