[jboss-svn-commits] JBL Code SVN: r19372 - in labs/jbossrules/contrib/machinelearning/decisiontree/src: dt/builder and 3 other directories.

jboss-svn-commits at lists.jboss.org jboss-svn-commits at lists.jboss.org
Tue Apr 1 21:29:16 EDT 2008


Author: gizil
Date: 2008-04-01 21:29:16 -0400 (Tue, 01 Apr 2008)
New Revision: 19372

Added:
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java
Removed:
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java
Modified:
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java
Log:
optimizing before recursive discretization

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -5,9 +5,6 @@
 import java.util.List;
 
 import dt.memory.Domain;
-import dt.memory.Fact;
-import dt.memory.FactTargetDistribution;
-import dt.tools.Util;
 
 public class DecisionTree {
 
@@ -33,93 +30,7 @@
 		this.attrsToClassify = new ArrayList<String>();
 	}
 
-	public Object getMajority(List<Fact> facts) {
-		List<?> targetValues = getPossibleValues(this.target);
-		Hashtable<Object, Integer> facts_in_class = getStatistics(facts, target);
 
-		int winner_vote = 0;
-		Object winner = null;
-		for (Object key : targetValues) {
-
-			int num_in_class = facts_in_class.get(key).intValue();
-			if (num_in_class > winner_vote) {
-				winner_vote = num_in_class;
-				winner = key;
-			}
-		}
-		return winner;
-	}
-
-	// *OPT* public double getInformation(List<FactSet> facts) {
-	public Hashtable<Object, Integer> getStatistics(List<Fact> facts, String target) {
-
-		List<?> targetValues = getPossibleValues(this.target);
-		Hashtable<Object, Integer> facts_in_class = new Hashtable<Object, Integer>(
-				targetValues.size());
-
-		for (Object t : targetValues) {
-			facts_in_class.put(t, 0);
-		}
-
-		int total_num_facts = 0;
-		// *OPT* for (FactSet fs: facts) {
-		// *OPT* for (Fact f: fs.getFacts()) {
-		for (Fact f : facts) {
-			total_num_facts++;
-			Object key = f.getFieldValue(target);
-			// System.out.println("My key: "+ key.toString());
-			facts_in_class.put(key, facts_in_class.get(key).intValue() + 1); // bocuk
-			// kafa
-			// :P
-		}
-		FACTS_READ += facts.size();
-		// *OPT* }
-		// *OPT* }
-		return facts_in_class;
-	}
-	
-	// *OPT* public double getInformation(List<FactSet> facts) {
-	public FactTargetDistribution getDistribution(List<Fact> facts) {
-		
-		FactTargetDistribution facts_in_class = new FactTargetDistribution(getDomain(getTarget()));
-		facts_in_class.calculateDistribution(facts);
-		FACTS_READ += facts.size();
-		return facts_in_class;
-	}
-
-	// *OPT* public double getInformation(List<FactSet> facts) {
-	/**
-	 * it returns the information value of facts entropy that characterizes the
-	 * (im)purity of an arbitrary collection of examples
-	 * 
-	 * @param facts
-	 *            list of facts
-	 */
-	public double getInformation_old(List<Fact> facts) {
-
-		List<?> targetValues = getPossibleValues(this.target);
-		Hashtable<Object, Integer> facts_in_class = getStatistics(facts,
-				getTarget()); // , targetValues)
-		// Hashtable<Object, Integer> facts_in_class = getStatistics(facts,
-		// getTarget(), targetValues);
-		int total_num_facts = facts.size();
-		double sum = 0;
-		for (Object key : targetValues) {
-			int num_in_class = facts_in_class.get(key).intValue();
-			// System.out.println("num_in_class : "+ num_in_class + " key "+ key
-			// + " and the total num "+ total_num_facts);
-			double prob = (double) num_in_class / (double) total_num_facts;
-
-			// double log2= Util.log2(prob);
-			// double plog2p= prob*log2;
-			sum += (prob == 0.0) ? 0.0 : -1 * prob * Util.log2(prob);
-			// System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"
-			// where the sum: "+sum);
-		}
-		return sum;
-	}
-
-
 	public void setTarget(String targetField) {
 		target = targetField;
 		attrsToClassify.remove(target);

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -11,6 +11,7 @@
 import dt.LeafNode;
 import dt.TreeNode;
 
+import dt.memory.FactDistribution;
 import dt.memory.FactTargetDistribution;
 import dt.memory.WorkingMemory;
 import dt.memory.Fact;
@@ -83,7 +84,7 @@
 		}
 		dt.FACTS_READ += facts.size();
 
-		num_fact_processed = facts.size();
+		setNum_fact_processed(facts.size());
 
 		if (workingAttributes != null)
 			for (String attr : workingAttributes) {
@@ -123,7 +124,7 @@
 			}
 		}
 		dt.FACTS_READ += facts.size();
-		num_fact_processed = facts.size();
+		setNum_fact_processed(facts.size());
 
 		if (workingAttributes != null)
 			for (String attr : workingAttributes) {
@@ -158,44 +159,15 @@
 		
 		//FactTargetDistribution stats = dt.getDistribution(facts);
 		
-		FactTargetDistribution stats = new FactTargetDistribution(dt.getDomain(dt.getTarget()));
+		FactDistribution stats = new FactDistribution(dt.getDomain(dt.getTarget()));
 		stats.calculateDistribution(facts);
-	
 		stats.evaluateMajority();
-//		
-//		Object winner1 = stats.getThe_winner_target_class();
-//		for (Object looser: stats.getTargetClasses()) {
-//			System.out.println(" the target class = "+ looser);
-//			if (!winner1.equals(looser) && stats.getVoteFor(looser)>0) {
-//				System.out.println(" the num of supporters = "+ stats.getVoteFor(looser));
-//				System.out.println(" but the guys "+ stats.getSupportersFor(looser));
-//				System.out.println("How many bok: "+stats.getSupportersFor(looser).size());
-//				//unclassified_facts.addAll(stats.getSupportersFor(looser));
-//			} else
-//				System.out.println(Util.ntimes("DANIEL", 5)+ "how many times not matching?? not a looser "+ looser );
-//		}
-		/*
-		Collection<Object> targetValues = stats.keySet();
-		int winner_vote = 0;
-		int num_supporters = 0;
-		Object winner = null;
-		for (Object key : targetValues) {
 
-			int num_in_class = stats.get(key).intValue();
-			if (num_in_class > 0)
-				num_supporters++;
-			if (num_in_class > winner_vote) {
-				winner_vote = num_in_class;
-				winner = key;
-			}
-		}
-		*
-
 		/* if all elements are classified to the same value */
 		if (stats.getNum_supported_target_classes() == 1) {
 
 			LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
-			classifiedNode.setRank((double) facts.size()/(double) num_fact_processed);
+			classifiedNode.setRank((double) facts.size()/(double) getNum_fact_processed());
 			classifiedNode.setNumSupporter(facts.size());
 			
 			return classifiedNode;

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -12,4 +12,6 @@
 
 	DecisionTree build(WorkingMemory simple, String klass_name, String target_attr,List<String> workingAttributes);
 
+	int getNum_fact_processed();
+	void setNum_fact_processed(int num);
 }

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -14,7 +14,9 @@
 import dt.TreeNode;
 import dt.memory.Domain;
 import dt.memory.Fact;
+import dt.memory.FactDistribution;
 import dt.memory.FactSet;
+import dt.memory.FactTargetDistribution;
 import dt.memory.OOFactSet;
 import dt.memory.WorkingMemory;
 import dt.tools.Util;
@@ -197,27 +199,15 @@
 			throw new RuntimeException("Nothing to classify, factlist is empty");
 		}
 		/* let's get the statistics of the results */
-		List<?> targetValues = dt.getPossibleValues(dt.getTarget());	
-		Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget());//,targetValues
+		FactDistribution stats = new FactDistribution(dt.getDomain(dt.getTarget()));
+		stats.calculateDistribution(facts);
+		
+		stats.evaluateMajority();
 
-		int winner_vote = 0;
-		int num_supporters = 0;
-		Object winner = null;		
-		for (Object key: targetValues) {
-
-			int num_in_class = stats.get(key).intValue();
-			if (num_in_class>0)
-				num_supporters ++;
-			if (num_in_class > winner_vote) {
-				winner_vote = num_in_class;
-				winner = key;
-			}
-		}
-
 		/* if all elements are classified to the same value */
-		if (num_supporters == 1) {
+		if (stats.getNum_supported_target_classes() == 1) {
 			//*OPT*			return new LeafNode(facts.get(0).getFact(0).getFieldValue(target));
-			LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+			LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
 			classifiedNode.setRank((double)facts.size()/(double)num_fact_processed);
 			classifiedNode.setNumSupporter(facts.size());
 			return classifiedNode;
@@ -226,9 +216,10 @@
 		/* if  there is no attribute left in order to continue */
 		if (attributeNames.size() == 0) {
 			/* an heuristic of the leaf classification*/
+			Object winner = stats.getThe_winner_target_class();
 			LeafNode noAttributeLeftNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
-			noAttributeLeftNode.setRank((double)winner_vote/(double)num_fact_processed);
-			noAttributeLeftNode.setNumSupporter(winner_vote);
+			noAttributeLeftNode.setRank((double)stats.getVoteFor(winner)/(double)num_fact_processed);
+			noAttributeLeftNode.setNumSupporter(stats.getVoteFor(winner));
 			return noAttributeLeftNode;
 		}
 
@@ -259,7 +250,7 @@
 
 			if (filtered_facts.get(value).isEmpty()) {
 				/* majority !!!! */
-				LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+				LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
 				majorityNode.setRank(0.0);
 				currentNode.addNode(value, majorityNode);
 			} else {

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -4,13 +4,13 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
-import java.util.Hashtable;
 import java.util.Iterator;
 import java.util.List;
 
 import dt.DecisionTree;
 import dt.memory.Domain;
 import dt.memory.Fact;
+import dt.memory.FactAttrDistribution;
 import dt.memory.FactDistribution;
 import dt.memory.FactTargetDistribution;
 import dt.tools.Util;
@@ -18,7 +18,7 @@
 public class Entropy implements InformationMeasure {
 	
 	public static Domain<?> chooseContAttribute(DecisionTree dt, List<Fact> facts,
-			FactTargetDistribution facts_in_class, List<String> attrs) {
+			FactDistribution facts_in_class, List<String> attrs) {
 
 		double dt_info = calc_info(facts_in_class);
 		double greatestGain = -100000.0;
@@ -102,7 +102,7 @@
 		keys.add(key1);
 		
 		
-		FactDistribution facts_at_attribute = new FactDistribution(keys, targetValues);
+		FactAttrDistribution facts_at_attribute = new FactAttrDistribution(keys, targetDomain);
 		facts_at_attribute.setTotal(facts.size());
 		facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
 		facts_at_attribute.setSumForAttr(key1, facts.size());
@@ -212,7 +212,7 @@
 		keys.add(key1);
 		
 		
-		FactDistribution facts_at_attribute = new FactDistribution(keys, targetValues);
+		FactAttrDistribution facts_at_attribute = new FactAttrDistribution(keys, targetDomain);
 		facts_at_attribute.setTotal(facts.size());
 		facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
 		facts_at_attribute.setSumForAttr(key1, facts.size());
@@ -220,7 +220,7 @@
 		double best_sum = -100000.0;
 		Object value_to_split = splitValues.get(0);
 		int split_index =1, index = 1;
-		FactDistribution best_distribution;
+		FactAttrDistribution best_distribution = null;
 		Iterator<Fact> f_ite = facts.iterator();
 		Fact f1 = f_ite.next();
 		Comparator<Fact> targetComp = f1.getDomain(targetAttr).factComparator();
@@ -283,11 +283,11 @@
 			List<Integer> split_indices,
 			List<Fact> split_facts)
 		 */
-//		info_contattr_rec(facts.subList(0, split_index),
-//				splitDomain, targetDomain, 
-//				best_distribution.getAttrFor(key0), 
-//				split_indices,
-//				split_facts);
+		info_contattr_rec(facts.subList(0, split_index),
+				splitDomain, targetDomain, 
+				best_distribution.getAttrFor(key0), 
+				split_indices,
+				split_facts);
 		
 		
 		if (Util.DEBUG) {
@@ -325,124 +325,14 @@
 	 * instances of a single class or (b) some stopping criterion is reached. I
 	 * can't remember what stopping criteria they used.
 	 */
-	public static double info_contattr_old (List<Fact> facts,
-			Domain splitDomain, Domain<?> targetDomain, 
-			Hashtable<Object, Integer> facts_in_class, 
-			List<Integer> split_indices,
-			List<Fact> split_facts) {
-	
-		String splitAttr = splitDomain.getName();
-		List<?> splitValues = splitDomain.getValues();
-		String targetAttr = targetDomain.getName();
-		List<?> targetValues = targetDomain.getValues();
-		if (Util.DEBUG) {
-			System.out.println("entropy.info_cont() attributeToSplit? " + splitAttr);
-			int f_i=0;
-			for(Fact f: facts) {
-				System.out.println("entropy.info_cont() SORTING: "+f_i+" attr "+splitAttr+ " "+ f );
-				f_i++;
-			}
-		}
 
-		if (facts.size() <= 1) {
-			System.out
-					.println("The size of the fact list is 0 oups??? exiting....");
-			System.exit(0);
-		}
-		if (split_facts.size() < 1) {
-			System.out
-					.println("The size of the splits is 0 oups??? exiting....");
-			System.exit(0);
-		}
-		
-		/* initialize the distribution */
-		Object key0 = Integer.valueOf(0);
-		Object key1 = Integer.valueOf(1);
-		List<Object> keys = new ArrayList<Object>(2);
-		keys.add(key0);
-		keys.add(key1);
-		
-		
-		FactDistribution facts_at_attribute = new FactDistribution(keys, targetValues);
-		facts_at_attribute.setTotal(facts.size());
-		facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
-		facts_at_attribute.setSumForAttr(key1, facts.size());
-		
-		double best_sum = -100000.0;
-		Object value_to_split = splitValues.get(0);
-		int split_index =1, index = 1;
-		Iterator<Fact> f_ite = facts.iterator();
-		Fact f1 = f_ite.next();
-		Comparator<Fact> targetComp = f1.getDomain(targetAttr).factComparator();
-		if (Util.DEBUG)	System.out.println("\nentropy.info_cont() SEARCHING: "+split_index+" attr "+splitAttr+ " "+ f1 );
-		while (f_ite.hasNext()) {/* 2. Look for potential cut-points. */
-
-			Fact f2 = f_ite.next();
-			if (Util.DEBUG) System.out.print("entropy.info_cont() SEARCHING: "+(index+1)+" attr "+splitAttr+ " "+ f2 );
-			Object targetKey = f2.getFieldValue(targetAttr);
-			
-			// System.out.println("My key: "+ targetKey.toString());
-			//for (Object attr_key : attr_values)
-			
-			/* every time it change the place in the distribution */
-			facts_at_attribute.change(key0, targetKey, +1);
-			facts_at_attribute.change(key1, targetKey, -1);
-	
-			/*
-			 * 2.1 Cut points are points in the sorted list above where the class labels change. 
-			 * Eg. if I had five instances with values for the attribute of interest and labels 
-			 * (1.0,A), (1.4,A), (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only
-			 * two cutpoints of interest: 1.85 and 5 (mid-way between the points
-			 * where the classes change from A to B or vice versa).
-			 */
-			
-			if ( targetComp.compare(f1, f2)!=0) {
-				// the cut point
-				Number cp_i = (Number) f1.getFieldValue(splitAttr);
-				Number cp_i_next = (Number) f2.getFieldValue(splitAttr);
-
-				Number cut_point = (Double)(cp_i.doubleValue() + cp_i_next.doubleValue()) / 2;
-				
-				/*
-				 * 3. Evaluate your favourite disparity measure 
-				 * (info gain, gain ratio, gini coefficient, chi-squared test) on the cut point
-				 * and calculate its gain 
-				 */
-				double sum = calc_info_attr(facts_at_attribute);
-				//System.out.println("**entropy.info_contattr() FOUND: "+ sum + " best sum "+best_sum + 
-				if (Util.DEBUG) System.out.println("  **Try "+ sum + " best sum "+best_sum + 
-				" value ("+ f1.getFieldValue(splitAttr) +"-|"+ value_to_split+"|-"+ f2.getFieldValue(splitAttr)+")");
-				
-				if (sum > best_sum) {
-					best_sum = sum;
-					value_to_split = cut_point;
-					if (Util.DEBUG) System.out.println(Util.ntimes("?", 10)+"** FOUND: target ("+ f1.getFieldValue(targetAttr) +"-|T|-"+ f2.getFieldValue(targetAttr)+")");
-					split_index = index;
-				}
-			} else {}		
-			f1 = f2;
-			index++;
-		}
-		splitDomain.addPseudoValue(value_to_split);
-		Util.insert(split_indices, Integer.valueOf(split_index));
-		if (Util.DEBUG) {
-			System.out.println("entropy.info_contattr(BOK_last) split_indices.size "+split_indices.size());
-			for(Integer i : split_indices)
-				System.out.println("entropy.info_contattr(FOUNDS) split_indices "+i + " the fact "+facts.get(i));
-			System.out.println("entropy.chooseContAttribute(1.5)*********** num of split for "+
-					splitAttr+": "+ splitDomain.getValues().size());
-		}
-		return best_sum;
-	}
-	
-	
 	/* 
 	 * id3 uses that function because it can not classify continuous attributes
 	 */
 	public static String chooseAttribute(DecisionTree dt, List<Fact> facts,
-			Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
+			FactDistribution facts_in_class, List<String> attrs) {
 
-		double dt_info = calc_info(facts_in_class, facts.size());
+		double dt_info = calc_info(facts_in_class);//, facts.size()
 		double greatestGain = -1000;
 		String attributeWithGreatestGain = attrs.get(0);
 		String target = dt.getTarget();
@@ -467,17 +357,16 @@
 		return attributeWithGreatestGain;
 	}
 	
-	public static double info_attr(List<Fact> facts, 
-			 Domain<?> splitDomain, Domain<?> targetDomain) {
+	public static double info_attr(List<Fact> facts, Domain<?> splitDomain, Domain<?> targetDomain) {
 		String attributeToSplit = splitDomain.getName();
 		List<?> attributeValues = splitDomain.getValues();
 		String target = targetDomain.getName();
-		List<?> targetValues = targetDomain.getValues();
+		//List<?> targetValues = targetDomain.getValues();
 		
 		if (Util.DEBUG) System.out.println("What is the attributeToSplit? " + attributeToSplit);
 
 		/* initialize the hashtable */
-		FactDistribution facts_at_attribute = new FactDistribution(attributeValues, targetValues);
+		FactAttrDistribution facts_at_attribute = new FactAttrDistribution(attributeValues, targetDomain);
 		facts_at_attribute.setTotal(facts.size());
 		
 		for (Fact f : facts) {
@@ -499,7 +388,7 @@
 	/*
 	 * for both 
 	 */
-	private static double calc_info_attr( FactDistribution facts_of_attribute) {
+	private static double calc_info_attr( FactAttrDistribution facts_of_attribute) {
 		Collection<Object> attributeValues = facts_of_attribute.getAttributes();
 		int fact_size = facts_of_attribute.getTotal();
 		double sum = 0.0;
@@ -508,37 +397,24 @@
 			//double sum_attr = 0.0;
 			if (total_num_attr > 0) {
 				sum += ((double) total_num_attr / (double) fact_size) * 
-					calc_info(facts_of_attribute.getAttrFor(attr), total_num_attr);
+					calc_info(facts_of_attribute.getAttrFor(attr));
 			}
 		}
 		return sum;
 	}
 
-	/*
+	/* you can calculate this before */
+	/**
+	 * it returns the information value of facts entropy that characterizes the
+	 * (im)purity of an arbitrary collection of examples
 	 * 
+	 * @param facts
+	 *            list of facts
 	 */
-	public static double calc_info(Hashtable<Object, Integer> facts_in_class,
-			int total_num_facts) {
-		Collection<Object> targetValues = facts_in_class.keySet();
-		double prob, sum = 0;
-		for (Object key : targetValues) {
-			int num_in_class = facts_in_class.get(key).intValue();
-			// System.out.println("num_in_class : "+ num_in_class + " key "+ key+ " and the total num "+ total_num_facts);
-			
-			if (num_in_class > 0) {
-				prob = (double) num_in_class / (double) total_num_facts;
-				/* TODO what if it is a sooo small number ???? */
-				sum +=  -1 * prob * Util.log2(prob);
-			// System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"where the sum: "+sum);
-			}
-		}
-		return sum;
-	}
-	/* you can calculate this before */
 	public static double calc_info(FactTargetDistribution facts_in_class) {
 		
 		int total_num_facts = facts_in_class.getSum();
-		Collection<Object> targetValues = facts_in_class.getTargetClasses();
+		Collection<?> targetValues = facts_in_class.getTargetClasses();
 		double prob, sum = 0;
 		for (Object key : targetValues) {
 			int num_in_class = facts_in_class.getVoteFor(key);

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -12,6 +12,8 @@
 import dt.LeafNode;
 import dt.TreeNode;
 
+import dt.memory.FactDistribution;
+import dt.memory.FactTargetDistribution;
 import dt.memory.WorkingMemory;
 import dt.memory.Fact;
 import dt.memory.FactSet;
@@ -72,7 +74,7 @@
 		}
 		dt.FACTS_READ += facts.size();
 		
-		num_fact_processed = facts.size();
+		setNum_fact_processed(facts.size());
 			
 		if (workingAttributes != null)
 			for (String attr: workingAttributes) {
@@ -112,7 +114,7 @@
 			}
 		}
 		dt.FACTS_READ += facts.size();
-		num_fact_processed = facts.size(); 
+		setNum_fact_processed(facts.size());
 			
 		if (workingAttributes != null)
 			for (String attr: workingAttributes) {
@@ -143,27 +145,15 @@
 		}
 		/* let's get the statistics of the results */
 		//List<?> targetValues = dt.getPossibleValues(dt.getTarget());	
-		Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget());//targetValues
-		Collection<Object> targetValues = stats.keySet();
+		FactDistribution stats = new FactDistribution(dt.getDomain(dt.getTarget()));
+		stats.calculateDistribution(facts);
 		
-		int winner_vote = 0;
-		int num_supporters = 0;
-		Object winner = null;		
-		for (Object key: targetValues) {
+		stats.evaluateMajority();
 
-			int num_in_class = stats.get(key).intValue();
-			if (num_in_class>0)
-				num_supporters ++;
-			if (num_in_class > winner_vote) {
-				winner_vote = num_in_class;
-				winner = key;
-			}
-		}
-
 		/* if all elements are classified to the same value */
-		if (num_supporters == 1) {
+		if (stats.getNum_supported_target_classes() == 1) {
 			//*OPT*			return new LeafNode(facts.get(0).getFact(0).getFieldValue(target));
-			LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+			LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
 			classifiedNode.setRank((double)facts.size()/(double)num_fact_processed);
 			classifiedNode.setNumSupporter(facts.size());
 			return classifiedNode;
@@ -172,9 +162,10 @@
 		/* if  there is no attribute left in order to continue */
 		if (attributeNames.size() == 0) {
 			/* an heuristic of the leaf classification*/
+			Object winner = stats.getThe_winner_target_class();
 			LeafNode noAttributeLeftNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
-			noAttributeLeftNode.setRank((double)winner_vote/(double)num_fact_processed);
-			noAttributeLeftNode.setNumSupporter(winner_vote);
+			noAttributeLeftNode.setRank((double)stats.getVoteFor(winner)/(double)num_fact_processed);
+			noAttributeLeftNode.setNumSupporter(stats.getVoteFor(winner));
 			return noAttributeLeftNode;
 		}
 
@@ -206,7 +197,7 @@
 			
 			if (filtered_facts.get(value).isEmpty()) {
 				/* majority !!!! */
-				LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+				LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
 				majorityNode.setRank(-1.0);
 				majorityNode.setNumSupporter(filtered_facts.get(value).size());
 				currentNode.addNode(value, majorityNode);
@@ -222,4 +213,12 @@
 	public int getNumCall() {
 		return FUNC_CALL;
 	}
+
+
+	public int getNum_fact_processed() {
+		return num_fact_processed;
+	}
+	public void setNum_fact_processed(int num_fact_processed) {
+		this.num_fact_processed = num_fact_processed;
+	}
 }

Copied: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java (from rev 19371, labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java)
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -0,0 +1,110 @@
+package dt.memory;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Hashtable;
+import java.util.List;
+
+import dt.tools.Util;
+
+/**
+ * Histogram of the target classes, split by the values of one attribute:
+ * every attribute value owns its own {@link FactTargetDistribution}.
+ */
+public class FactAttrDistribution {
+	/* key handed to FactTargetDistribution.change() to adjust its running sum */
+	String attr_sum = Util.sum();
+
+	/* attribute value -> target-class distribution of the facts with that value */
+	Hashtable<Object, FactTargetDistribution> facts_at_attr;
+
+	/* target domain every per-value distribution is built on; clone() needs it */
+	private final Domain<?> targetDomain;
+
+	/* total number of facts, set explicitly by the caller via setTotal() */
+	private int total_num;
+
+	/**
+	 * Creates an empty target distribution for each attribute value.
+	 * 
+	 * @param attributeValues
+	 *            the values of the attribute, one bucket per value
+	 * @param targetDomain
+	 *            domain of the target attribute
+	 */
+	public FactAttrDistribution(List<?> attributeValues, Domain<?> targetDomain) {
+		this.targetDomain = targetDomain;
+		facts_at_attr = new Hashtable<Object, FactTargetDistribution>(attributeValues.size());
+		for (Object attr : attributeValues) {
+			facts_at_attr.put(attr, new FactTargetDistribution(targetDomain));
+		}
+	}
+
+	/**
+	 * Returns a copy that owns its own per-value distributions.
+	 * Must not delegate to {@code this.clone()}: that recursed forever.
+	 */
+	@Override
+	public FactAttrDistribution clone() {
+		FactAttrDistribution copy = new FactAttrDistribution(
+				new ArrayList<Object>(facts_at_attr.keySet()), targetDomain);
+		copy.setTotal(total_num);
+		for (Object attr : facts_at_attr.keySet()) {
+			// NOTE(review): assumes setDistribution() copies the counts rather than sharing them - confirm
+			copy.setTargetDistForAttr(attr, getAttrFor(attr));
+			copy.setSumForAttr(attr, getSumForAttr(attr));
+		}
+		return copy;
+	}
+
+	/** Sets the total number of facts this distribution was built from. */
+	public void setTotal(int size) {
+		this.total_num = size;
+	}
+
+	/** @return the value last passed to {@link #setTotal(int)} (0 by default) */
+	public int getTotal() {
+		return this.total_num;
+	}
+
+	/** @return the target distribution for the given attribute value, or null if unknown */
+	public FactTargetDistribution getAttrFor(Object attr_value) {
+		return facts_at_attr.get(attr_value);
+	}
+
+	/** @return the running sum of the distribution for the given attribute value */
+	public int getSumForAttr(Object attr_value) {
+		return facts_at_attr.get(attr_value).getSum();
+	}
+
+	/** Overwrites the running sum of the distribution for the given attribute value. */
+	public void setSumForAttr(Object attr_value, int total) {
+		facts_at_attr.get(attr_value).setSum(total);
+	}
+
+	/**
+	 * Loads {@code targetDist} into the distribution kept for {@code attr_value}.
+	 * The stored instance is kept; its contents are replaced through
+	 * {@code FactTargetDistribution.setDistribution()}.
+	 */
+	public void setTargetDistForAttr(Object attr_value, FactTargetDistribution targetDist) {
+		facts_at_attr.get(attr_value).setDistribution(targetDist);
+	}
+
+	/**
+	 * Adds {@code i} (may be negative) to the count of {@code targetValue}
+	 * among the facts with {@code attrValue}, and the same amount to the
+	 * entry stored under the {@code attr_sum} key for that value.
+	 */
+	public void change(Object attrValue, Object targetValue, int i) {
+		FactTargetDistribution dist = facts_at_attr.get(attrValue);
+		dist.change(targetValue, i);
+		// NOTE(review): double-counts if FactTargetDistribution.change() also bumps the sum - confirm
+		dist.change(attr_sum, i);
+	}
+
+	/** @return a live view of the attribute values this distribution is keyed by */
+	public Collection<Object> getAttributes() {
+		return facts_at_attr.keySet();
+	}
+}

Deleted: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -1,78 +0,0 @@
-package dt.memory;
-
-import java.util.Collection;
-import java.util.Hashtable;
-import java.util.List;
-
-import dt.tools.Util;
-
-public class FactDistribution {
-	String attr_sum = Util.sum();
-	
-	Hashtable<Object, Hashtable<Object, Integer>> facts_at_attr;
-
-	private int total_num;
-	
-	public FactDistribution(List<?> attributeValues, List<?>  targetValues) {
-		facts_at_attr = new Hashtable<Object, Hashtable<Object, Integer>>(attributeValues.size());
-		
-		for (Object attr : attributeValues) {
-			facts_at_attr.put(attr, new Hashtable<Object, Integer>(targetValues.size() + 1));
-			for (Object t : targetValues) {
-				facts_at_attr.get(attr).put(t, 0);
-			}
-			facts_at_attr.get(attr).put(attr_sum, 0);
-		}
-		
-	}
-	
-	public FactDistribution clone() {
-		return this.clone();
-	}
-	
-	public void setTotal(int size) {
-		this.total_num = size;	
-	}
-	public int getTotal() {
-		return this.total_num;	
-	}
-	
-	public Hashtable<Object, Integer> getAttrFor(Object attr_value) {
-		return facts_at_attr.get(attr_value);
-	}
-	
-	public int getSumForAttr(Object attr_value) {
-		return facts_at_attr.get(attr_value).get(attr_sum).intValue();
-	}
-	
-	public void setSumForAttr(Object attr_value, int total) {
-		facts_at_attr.get(attr_value).put(attr_sum,total);
-	}
-	
-	public void setTargetDistForAttr(Object attr_value, Hashtable<Object, Integer> targetDist) {
-		for (Object target: targetDist.keySet())
-			facts_at_attr.get(attr_value).put(target,targetDist.get(target));
-	}
-	
-	public void setTargetDistForAttr(Object attr_value, FactTargetDistribution targetDist) {
-		for (Object target: targetDist.getTargetClasses())
-			facts_at_attr.get(attr_value).put(target,targetDist.getVoteFor(target));
-	}
-
-	public void change(Object attrValue, Object targetValue, int i) {
-		int num_1 = facts_at_attr.get(attrValue).get(targetValue).intValue();
-		num_1 += i;
-		facts_at_attr.get(attrValue).put(targetValue, num_1);
-		
-		int total_num_1 = facts_at_attr.get(attrValue).get(attr_sum).intValue();
-		total_num_1 += i;
-		facts_at_attr.get(attrValue).put(attr_sum, total_num_1);
-		
-	}
-
-	public Collection<Object> getAttributes() {
-		return facts_at_attr.keySet();
-	}
-	
-
-}

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -1,107 +1,65 @@
 package dt.memory;
 
-import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Hashtable;
 import java.util.List;
 
 import dt.tools.Util;
 
+/* simple histogram keeps the number of facts in each class of target */
 public class FactTargetDistribution {
 	
 	private String attr_sum = Util.sum();
-	private Domain<?> targetDomain;
+	protected Domain<?> targetDomain;
 	private Hashtable<Object, Integer> num_at_target;
-	private Hashtable<Object, List<Fact>> facts_at_target;
 	
-	private int num_supported_target_classes;
-	private Object the_winner_target_class;
-	
 	public FactTargetDistribution(Domain<?> targetDomain) {
-		
-//		this.targetDomain = targetDomain.clone();
-//		targetDomain.
-		
-		num_supported_target_classes = 0;
+
 		this.targetDomain = targetDomain;
 		List<?> targetValues = targetDomain.getValues();
-		num_at_target =  new Hashtable<Object, Integer>(targetValues.size() + 1);
-		facts_at_target = new Hashtable<Object, List<Fact>>(targetValues.size());
+		num_at_target =  new Hashtable<Object, Integer>(targetValues.size() + 1);		
 		for (Object t : targetValues) {
-			num_at_target.put(t, 0);
-			facts_at_target.put(t, new ArrayList<Fact>());
+			num_at_target.put(t, 0);			
 		}
 		num_at_target.put(attr_sum, 0);
 		
 	}
 	
-	public void calculateDistribution(List<Fact> facts){
-		int total_num_facts = 0;
-		String target = targetDomain.getName();
-		for (Fact f : facts) {
-			total_num_facts++;
-			Object key = f.getFieldValue(target);
-			// System.out.println("My key: "+ key.toString());
-			num_at_target.put(key, num_at_target.get(key).intValue() + 1); // bocuk
-			facts_at_target.get(key).add(f);
-
-		}
-		num_at_target.put(attr_sum, num_at_target.get(attr_sum).intValue() + total_num_facts);
-		
+	public Collection<?> getTargetClasses() {
+		return targetDomain.getValues();
 	}
-	public Collection<Object> getTargetClasses() {
-		return facts_at_target.keySet();
-	}
+	
 	public int getSum() {
 		return num_at_target.get(attr_sum).intValue();
 	}
+	public void setSum(int sum) {
+		num_at_target.put(attr_sum, sum);
+	}
 	
 	public int getVoteFor(Object value) {
 		return num_at_target.get(value).intValue();
 	}
 	
-	public List<Fact> getSupportersFor(Object value) {
-		return facts_at_target.get(value);
+	public Domain<?> getTargetDomain() {
+		return targetDomain;
 	}
-	public void evaluateMajority() {
-		
-		List<?> targetValues = targetDomain.getValues();
-		int winner_vote = 0;
-		int num_supporters = 0;
-		
-		Object winner = null;
-		for (Object key : targetValues) {
 
-			int num_in_class = num_at_target.get(key).intValue();
-			if (num_in_class > 0)
-				num_supporters++;
-			if (num_in_class > winner_vote) {
-				winner_vote = num_in_class;
-				winner = key;
-			}
-		}
-		setNum_supperted_target_classes(num_supporters);
-		setThe_winner_target_class(winner);
-		
+	public void setTargetDomain(Domain<?> targetDomain) {
+		this.targetDomain = targetDomain.clone();
 	}
-
-	public int getNum_supported_target_classes() {
-		return num_supported_target_classes;
+	
+	public void change(Object targetValue, int i) {
+		int num_1 = num_at_target.get(targetValue).intValue();
+		num_1 += i;
+		num_at_target.put(targetValue, num_1);
 	}
 
-	public void setNum_supperted_target_classes(int num_supperted_target_classes) {
-		this.num_supported_target_classes = num_supperted_target_classes;
+	public void setDistribution(FactTargetDistribution targetDist) {
+		for (Object targetValue: targetDomain.getValues()) {
+			num_at_target.put(targetValue, targetDist.getVoteFor(targetValue));
+		}
+		
 	}
-
-	public Object getThe_winner_target_class() {
-		return the_winner_target_class;
-	}
-
-	public void setThe_winner_target_class(Object the_winner_target_class) {
-		this.the_winner_target_class = the_winner_target_class;
-	}
 	
-	
-	
 
 }

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -8,6 +8,7 @@
 
 import dt.memory.Domain;
 import dt.memory.Fact;
+import dt.memory.FactDistribution;
 import dt.memory.FactTargetDistribution;
 
 public class FactProcessor {
@@ -101,7 +102,7 @@
 	}
 
 	public static void splitUnclassifiedFacts(
-			List<Fact> unclassified_facts, FactTargetDistribution stats) {
+			List<Fact> unclassified_facts, FactDistribution stats) {
 		
 		Object winner = stats.getThe_winner_target_class();
 		System.out.println(Util.ntimes("DANIEL", 2)+ " lets get unclassified daniel winner "+winner +" num of sup "  +stats.getVoteFor(winner));

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -22,7 +22,7 @@
 			dt = System.currentTimeMillis() - dt;
 			System.out.println("Time" + dt + "\n" + bocuksTree);
 
-			RulePrinter my_printer = new RulePrinter();
+			RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed());
 			boolean sort_via_rank = true;
 			my_printer.printer(bocuksTree, "examples", "src/rules/examples/"+drlfile, sort_via_rank);
 			

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -39,14 +39,13 @@
 		this.num_facts = num_facts;
 	}
 	
-	public RulePrinter() {
-		ruleText = new ArrayList<String>();
-		//rule_list = new ArrayList<ArrayList<NodeValue>>();
-		rules = new ArrayList<Rule>();
-		
-		/* most important */
-		nodes = new Stack<NodeValue>();
+	public int getNum_facts() {
+		return num_facts;
 	}
+
+	public void setNum_facts(int num_facts) {
+		this.num_facts = num_facts;
+	}
 	
 	public void printer(DecisionTree dt, String packageName, String outputFile, boolean sort) {//, PrintStream object
 		ruleObject = dt.getName();
@@ -68,11 +67,12 @@
 			Collections.sort(rules, Rule.getRankComparator());
 		
 		int total_num_facts=0;
-		int i = 0;
+		int i = 0, active_i = 0;
 		for( Rule rule: rules) {
 			i++;
 			if (ONLY_ACTIVE) {
 				if (rule.getRank() >= 0) {
+					active_i++;
 					System.out.println("//Active rules " +i + " write to drl \n"+ rule +"\n");
 					if (outputFile!=null) {
 						write(rule.toString(), true, outputFile);
@@ -81,16 +81,20 @@
 				}
 
 			} else {
+				if (rule.getRank() >= 0) {
+					active_i++;
+				}
 				System.out.println("//rule " +i + " write to drl \n"+ rule +"\n");
 				if (outputFile!=null) {
 					write(rule.toString(), true, outputFile);
 					write("\n", true, outputFile);
 				}
 			}
-			total_num_facts += rule.getPopularity();
+			total_num_facts += rule.getPopularity();		
 		}
 		if (outputFile!=null) {
-			write("//THE END: Total number of facts correctly classified= "+ total_num_facts, true, outputFile);
+			write("//THE END: Total number of facts correctly classified= "+ total_num_facts + " over "+ getNum_facts() , true, outputFile);
+			write("\n//with " + active_i + " number of rules over "+i+" total number of rules ", true, outputFile);
 			write("\n", true, outputFile); // EOF
 		}
 	}

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -39,7 +39,7 @@
 			System.out.println("Time"+dt + " facts read: "+bocuksTree.getNumRead() + " num call: "+ bocuk.getNumCall() );
 			//System.out.println(bocuksTree);
 
-			RulePrinter my_printer = new RulePrinter();
+			RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed());
 			boolean sort_via_rank = true;
 			my_printer.printer(bocuksTree, null, null, sort_via_rank);
 		}

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java	2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java	2008-04-02 01:29:16 UTC (rev 19372)
@@ -45,7 +45,7 @@
 		dt = System.currentTimeMillis() - dt;
 		System.out.println("Time"+dt+"\n"+bocuksTree);
 		
-		RulePrinter my_printer = new RulePrinter();
+		RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed());
 		boolean sort_via_rank = true;
 		my_printer.printer(bocuksTree,"test" , new String("../dt_learning/src/test/rules"+".drl"), sort_via_rank);
 	}




More information about the jboss-svn-commits mailing list