[jboss-svn-commits] JBL Code SVN: r19372 - in labs/jbossrules/contrib/machinelearning/decisiontree/src: dt/builder and 3 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Tue Apr 1 21:29:16 EDT 2008
Author: gizil
Date: 2008-04-01 21:29:16 -0400 (Tue, 01 Apr 2008)
New Revision: 19372
Added:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java
Removed:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java
Modified:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java
Log:
optimizing before recursive discretization
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -5,9 +5,6 @@
import java.util.List;
import dt.memory.Domain;
-import dt.memory.Fact;
-import dt.memory.FactTargetDistribution;
-import dt.tools.Util;
public class DecisionTree {
@@ -33,93 +30,7 @@
this.attrsToClassify = new ArrayList<String>();
}
- public Object getMajority(List<Fact> facts) {
- List<?> targetValues = getPossibleValues(this.target);
- Hashtable<Object, Integer> facts_in_class = getStatistics(facts, target);
- int winner_vote = 0;
- Object winner = null;
- for (Object key : targetValues) {
-
- int num_in_class = facts_in_class.get(key).intValue();
- if (num_in_class > winner_vote) {
- winner_vote = num_in_class;
- winner = key;
- }
- }
- return winner;
- }
-
- // *OPT* public double getInformation(List<FactSet> facts) {
- public Hashtable<Object, Integer> getStatistics(List<Fact> facts, String target) {
-
- List<?> targetValues = getPossibleValues(this.target);
- Hashtable<Object, Integer> facts_in_class = new Hashtable<Object, Integer>(
- targetValues.size());
-
- for (Object t : targetValues) {
- facts_in_class.put(t, 0);
- }
-
- int total_num_facts = 0;
- // *OPT* for (FactSet fs: facts) {
- // *OPT* for (Fact f: fs.getFacts()) {
- for (Fact f : facts) {
- total_num_facts++;
- Object key = f.getFieldValue(target);
- // System.out.println("My key: "+ key.toString());
- facts_in_class.put(key, facts_in_class.get(key).intValue() + 1); // bocuk
- // kafa
- // :P
- }
- FACTS_READ += facts.size();
- // *OPT* }
- // *OPT* }
- return facts_in_class;
- }
-
- // *OPT* public double getInformation(List<FactSet> facts) {
- public FactTargetDistribution getDistribution(List<Fact> facts) {
-
- FactTargetDistribution facts_in_class = new FactTargetDistribution(getDomain(getTarget()));
- facts_in_class.calculateDistribution(facts);
- FACTS_READ += facts.size();
- return facts_in_class;
- }
-
- // *OPT* public double getInformation(List<FactSet> facts) {
- /**
- * it returns the information value of facts entropy that characterizes the
- * (im)purity of an arbitrary collection of examples
- *
- * @param facts
- * list of facts
- */
- public double getInformation_old(List<Fact> facts) {
-
- List<?> targetValues = getPossibleValues(this.target);
- Hashtable<Object, Integer> facts_in_class = getStatistics(facts,
- getTarget()); // , targetValues)
- // Hashtable<Object, Integer> facts_in_class = getStatistics(facts,
- // getTarget(), targetValues);
- int total_num_facts = facts.size();
- double sum = 0;
- for (Object key : targetValues) {
- int num_in_class = facts_in_class.get(key).intValue();
- // System.out.println("num_in_class : "+ num_in_class + " key "+ key
- // + " and the total num "+ total_num_facts);
- double prob = (double) num_in_class / (double) total_num_facts;
-
- // double log2= Util.log2(prob);
- // double plog2p= prob*log2;
- sum += (prob == 0.0) ? 0.0 : -1 * prob * Util.log2(prob);
- // System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"
- // where the sum: "+sum);
- }
- return sum;
- }
-
-
public void setTarget(String targetField) {
target = targetField;
attrsToClassify.remove(target);
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -11,6 +11,7 @@
import dt.LeafNode;
import dt.TreeNode;
+import dt.memory.FactDistribution;
import dt.memory.FactTargetDistribution;
import dt.memory.WorkingMemory;
import dt.memory.Fact;
@@ -83,7 +84,7 @@
}
dt.FACTS_READ += facts.size();
- num_fact_processed = facts.size();
+ setNum_fact_processed(facts.size());
if (workingAttributes != null)
for (String attr : workingAttributes) {
@@ -123,7 +124,7 @@
}
}
dt.FACTS_READ += facts.size();
- num_fact_processed = facts.size();
+ setNum_fact_processed(facts.size());
if (workingAttributes != null)
for (String attr : workingAttributes) {
@@ -158,44 +159,15 @@
//FactTargetDistribution stats = dt.getDistribution(facts);
- FactTargetDistribution stats = new FactTargetDistribution(dt.getDomain(dt.getTarget()));
+ FactDistribution stats = new FactDistribution(dt.getDomain(dt.getTarget()));
stats.calculateDistribution(facts);
-
stats.evaluateMajority();
-//
-// Object winner1 = stats.getThe_winner_target_class();
-// for (Object looser: stats.getTargetClasses()) {
-// System.out.println(" the target class = "+ looser);
-// if (!winner1.equals(looser) && stats.getVoteFor(looser)>0) {
-// System.out.println(" the num of supporters = "+ stats.getVoteFor(looser));
-// System.out.println(" but the guys "+ stats.getSupportersFor(looser));
-// System.out.println("How many bok: "+stats.getSupportersFor(looser).size());
-// //unclassified_facts.addAll(stats.getSupportersFor(looser));
-// } else
-// System.out.println(Util.ntimes("DANIEL", 5)+ "how many times not matching?? not a looser "+ looser );
-// }
- /*
- Collection<Object> targetValues = stats.keySet();
- int winner_vote = 0;
- int num_supporters = 0;
- Object winner = null;
- for (Object key : targetValues) {
- int num_in_class = stats.get(key).intValue();
- if (num_in_class > 0)
- num_supporters++;
- if (num_in_class > winner_vote) {
- winner_vote = num_in_class;
- winner = key;
- }
- }
- *
-
/* if all elements are classified to the same value */
if (stats.getNum_supported_target_classes() == 1) {
LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
- classifiedNode.setRank((double) facts.size()/(double) num_fact_processed);
+ classifiedNode.setRank((double) facts.size()/(double) getNum_fact_processed());
classifiedNode.setNumSupporter(facts.size());
return classifiedNode;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -12,4 +12,6 @@
DecisionTree build(WorkingMemory simple, String klass_name, String target_attr,List<String> workingAttributes);
+ int getNum_fact_processed();
+ void setNum_fact_processed(int num);
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -14,7 +14,9 @@
import dt.TreeNode;
import dt.memory.Domain;
import dt.memory.Fact;
+import dt.memory.FactDistribution;
import dt.memory.FactSet;
+import dt.memory.FactTargetDistribution;
import dt.memory.OOFactSet;
import dt.memory.WorkingMemory;
import dt.tools.Util;
@@ -197,27 +199,15 @@
throw new RuntimeException("Nothing to classify, factlist is empty");
}
/* let's get the statistics of the results */
- List<?> targetValues = dt.getPossibleValues(dt.getTarget());
- Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget());//,targetValues
+ FactDistribution stats = new FactDistribution(dt.getDomain(dt.getTarget()));
+ stats.calculateDistribution(facts);
+
+ stats.evaluateMajority();
- int winner_vote = 0;
- int num_supporters = 0;
- Object winner = null;
- for (Object key: targetValues) {
-
- int num_in_class = stats.get(key).intValue();
- if (num_in_class>0)
- num_supporters ++;
- if (num_in_class > winner_vote) {
- winner_vote = num_in_class;
- winner = key;
- }
- }
-
/* if all elements are classified to the same value */
- if (num_supporters == 1) {
+ if (stats.getNum_supported_target_classes() == 1) {
//*OPT* return new LeafNode(facts.get(0).getFact(0).getFieldValue(target));
- LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+ LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
classifiedNode.setRank((double)facts.size()/(double)num_fact_processed);
classifiedNode.setNumSupporter(facts.size());
return classifiedNode;
@@ -226,9 +216,10 @@
/* if there is no attribute left in order to continue */
if (attributeNames.size() == 0) {
/* a heuristic of the leaf classification */
+ Object winner = stats.getThe_winner_target_class();
LeafNode noAttributeLeftNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
- noAttributeLeftNode.setRank((double)winner_vote/(double)num_fact_processed);
- noAttributeLeftNode.setNumSupporter(winner_vote);
+ noAttributeLeftNode.setRank((double)stats.getVoteFor(winner)/(double)num_fact_processed);
+ noAttributeLeftNode.setNumSupporter(stats.getVoteFor(winner));
return noAttributeLeftNode;
}
@@ -259,7 +250,7 @@
if (filtered_facts.get(value).isEmpty()) {
/* majority !!!! */
- LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+ LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
majorityNode.setRank(0.0);
currentNode.addNode(value, majorityNode);
} else {
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -4,13 +4,13 @@
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
-import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import dt.DecisionTree;
import dt.memory.Domain;
import dt.memory.Fact;
+import dt.memory.FactAttrDistribution;
import dt.memory.FactDistribution;
import dt.memory.FactTargetDistribution;
import dt.tools.Util;
@@ -18,7 +18,7 @@
public class Entropy implements InformationMeasure {
public static Domain<?> chooseContAttribute(DecisionTree dt, List<Fact> facts,
- FactTargetDistribution facts_in_class, List<String> attrs) {
+ FactDistribution facts_in_class, List<String> attrs) {
double dt_info = calc_info(facts_in_class);
double greatestGain = -100000.0;
@@ -102,7 +102,7 @@
keys.add(key1);
- FactDistribution facts_at_attribute = new FactDistribution(keys, targetValues);
+ FactAttrDistribution facts_at_attribute = new FactAttrDistribution(keys, targetDomain);
facts_at_attribute.setTotal(facts.size());
facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
facts_at_attribute.setSumForAttr(key1, facts.size());
@@ -212,7 +212,7 @@
keys.add(key1);
- FactDistribution facts_at_attribute = new FactDistribution(keys, targetValues);
+ FactAttrDistribution facts_at_attribute = new FactAttrDistribution(keys, targetDomain);
facts_at_attribute.setTotal(facts.size());
facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
facts_at_attribute.setSumForAttr(key1, facts.size());
@@ -220,7 +220,7 @@
double best_sum = -100000.0;
Object value_to_split = splitValues.get(0);
int split_index =1, index = 1;
- FactDistribution best_distribution;
+ FactAttrDistribution best_distribution = null;
Iterator<Fact> f_ite = facts.iterator();
Fact f1 = f_ite.next();
Comparator<Fact> targetComp = f1.getDomain(targetAttr).factComparator();
@@ -283,11 +283,11 @@
List<Integer> split_indices,
List<Fact> split_facts)
*/
-// info_contattr_rec(facts.subList(0, split_index),
-// splitDomain, targetDomain,
-// best_distribution.getAttrFor(key0),
-// split_indices,
-// split_facts);
+ info_contattr_rec(facts.subList(0, split_index),
+ splitDomain, targetDomain,
+ best_distribution.getAttrFor(key0),
+ split_indices,
+ split_facts);
if (Util.DEBUG) {
@@ -325,124 +325,14 @@
* instances of a single class or (b) some stopping criterion is reached. I
* can't remember what stopping criteria they used.
*/
- public static double info_contattr_old (List<Fact> facts,
- Domain splitDomain, Domain<?> targetDomain,
- Hashtable<Object, Integer> facts_in_class,
- List<Integer> split_indices,
- List<Fact> split_facts) {
-
- String splitAttr = splitDomain.getName();
- List<?> splitValues = splitDomain.getValues();
- String targetAttr = targetDomain.getName();
- List<?> targetValues = targetDomain.getValues();
- if (Util.DEBUG) {
- System.out.println("entropy.info_cont() attributeToSplit? " + splitAttr);
- int f_i=0;
- for(Fact f: facts) {
- System.out.println("entropy.info_cont() SORTING: "+f_i+" attr "+splitAttr+ " "+ f );
- f_i++;
- }
- }
- if (facts.size() <= 1) {
- System.out
- .println("The size of the fact list is 0 oups??? exiting....");
- System.exit(0);
- }
- if (split_facts.size() < 1) {
- System.out
- .println("The size of the splits is 0 oups??? exiting....");
- System.exit(0);
- }
-
- /* initialize the distribution */
- Object key0 = Integer.valueOf(0);
- Object key1 = Integer.valueOf(1);
- List<Object> keys = new ArrayList<Object>(2);
- keys.add(key0);
- keys.add(key1);
-
-
- FactDistribution facts_at_attribute = new FactDistribution(keys, targetValues);
- facts_at_attribute.setTotal(facts.size());
- facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
- facts_at_attribute.setSumForAttr(key1, facts.size());
-
- double best_sum = -100000.0;
- Object value_to_split = splitValues.get(0);
- int split_index =1, index = 1;
- Iterator<Fact> f_ite = facts.iterator();
- Fact f1 = f_ite.next();
- Comparator<Fact> targetComp = f1.getDomain(targetAttr).factComparator();
- if (Util.DEBUG) System.out.println("\nentropy.info_cont() SEARCHING: "+split_index+" attr "+splitAttr+ " "+ f1 );
- while (f_ite.hasNext()) {/* 2. Look for potential cut-points. */
-
- Fact f2 = f_ite.next();
- if (Util.DEBUG) System.out.print("entropy.info_cont() SEARCHING: "+(index+1)+" attr "+splitAttr+ " "+ f2 );
- Object targetKey = f2.getFieldValue(targetAttr);
-
- // System.out.println("My key: "+ targetKey.toString());
- //for (Object attr_key : attr_values)
-
- /* every time it change the place in the distribution */
- facts_at_attribute.change(key0, targetKey, +1);
- facts_at_attribute.change(key1, targetKey, -1);
-
- /*
- * 2.1 Cut points are points in the sorted list above where the class labels change.
- * Eg. if I had five instances with values for the attribute of interest and labels
- * (1.0,A), (1.4,A), (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only
- * two cutpoints of interest: 1.85 and 5 (mid-way between the points
- * where the classes change from A to B or vice versa).
- */
-
- if ( targetComp.compare(f1, f2)!=0) {
- // the cut point
- Number cp_i = (Number) f1.getFieldValue(splitAttr);
- Number cp_i_next = (Number) f2.getFieldValue(splitAttr);
-
- Number cut_point = (Double)(cp_i.doubleValue() + cp_i_next.doubleValue()) / 2;
-
- /*
- * 3. Evaluate your favourite disparity measure
- * (info gain, gain ratio, gini coefficient, chi-squared test) on the cut point
- * and calculate its gain
- */
- double sum = calc_info_attr(facts_at_attribute);
- //System.out.println("**entropy.info_contattr() FOUND: "+ sum + " best sum "+best_sum +
- if (Util.DEBUG) System.out.println(" **Try "+ sum + " best sum "+best_sum +
- " value ("+ f1.getFieldValue(splitAttr) +"-|"+ value_to_split+"|-"+ f2.getFieldValue(splitAttr)+")");
-
- if (sum > best_sum) {
- best_sum = sum;
- value_to_split = cut_point;
- if (Util.DEBUG) System.out.println(Util.ntimes("?", 10)+"** FOUND: target ("+ f1.getFieldValue(targetAttr) +"-|T|-"+ f2.getFieldValue(targetAttr)+")");
- split_index = index;
- }
- } else {}
- f1 = f2;
- index++;
- }
- splitDomain.addPseudoValue(value_to_split);
- Util.insert(split_indices, Integer.valueOf(split_index));
- if (Util.DEBUG) {
- System.out.println("entropy.info_contattr(BOK_last) split_indices.size "+split_indices.size());
- for(Integer i : split_indices)
- System.out.println("entropy.info_contattr(FOUNDS) split_indices "+i + " the fact "+facts.get(i));
- System.out.println("entropy.chooseContAttribute(1.5)*********** num of split for "+
- splitAttr+": "+ splitDomain.getValues().size());
- }
- return best_sum;
- }
-
-
/*
* id3 uses that function because it can not classify continuous attributes
*/
public static String chooseAttribute(DecisionTree dt, List<Fact> facts,
- Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
+ FactDistribution facts_in_class, List<String> attrs) {
- double dt_info = calc_info(facts_in_class, facts.size());
+ double dt_info = calc_info(facts_in_class);//, facts.size()
double greatestGain = -1000;
String attributeWithGreatestGain = attrs.get(0);
String target = dt.getTarget();
@@ -467,17 +357,16 @@
return attributeWithGreatestGain;
}
- public static double info_attr(List<Fact> facts,
- Domain<?> splitDomain, Domain<?> targetDomain) {
+ public static double info_attr(List<Fact> facts, Domain<?> splitDomain, Domain<?> targetDomain) {
String attributeToSplit = splitDomain.getName();
List<?> attributeValues = splitDomain.getValues();
String target = targetDomain.getName();
- List<?> targetValues = targetDomain.getValues();
+ //List<?> targetValues = targetDomain.getValues();
if (Util.DEBUG) System.out.println("What is the attributeToSplit? " + attributeToSplit);
/* initialize the hashtable */
- FactDistribution facts_at_attribute = new FactDistribution(attributeValues, targetValues);
+ FactAttrDistribution facts_at_attribute = new FactAttrDistribution(attributeValues, targetDomain);
facts_at_attribute.setTotal(facts.size());
for (Fact f : facts) {
@@ -499,7 +388,7 @@
/*
* for both
*/
- private static double calc_info_attr( FactDistribution facts_of_attribute) {
+ private static double calc_info_attr( FactAttrDistribution facts_of_attribute) {
Collection<Object> attributeValues = facts_of_attribute.getAttributes();
int fact_size = facts_of_attribute.getTotal();
double sum = 0.0;
@@ -508,37 +397,24 @@
//double sum_attr = 0.0;
if (total_num_attr > 0) {
sum += ((double) total_num_attr / (double) fact_size) *
- calc_info(facts_of_attribute.getAttrFor(attr), total_num_attr);
+ calc_info(facts_of_attribute.getAttrFor(attr));
}
}
return sum;
}
- /*
+ /* you can calculate this before */
+ /**
+ * it returns the information value of facts entropy that characterizes the
+ * (im)purity of an arbitrary collection of examples
*
+ * @param facts
+ * list of facts
*/
- public static double calc_info(Hashtable<Object, Integer> facts_in_class,
- int total_num_facts) {
- Collection<Object> targetValues = facts_in_class.keySet();
- double prob, sum = 0;
- for (Object key : targetValues) {
- int num_in_class = facts_in_class.get(key).intValue();
- // System.out.println("num_in_class : "+ num_in_class + " key "+ key+ " and the total num "+ total_num_facts);
-
- if (num_in_class > 0) {
- prob = (double) num_in_class / (double) total_num_facts;
- /* TODO what if it is a sooo small number ???? */
- sum += -1 * prob * Util.log2(prob);
- // System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"where the sum: "+sum);
- }
- }
- return sum;
- }
- /* you can calculate this before */
public static double calc_info(FactTargetDistribution facts_in_class) {
int total_num_facts = facts_in_class.getSum();
- Collection<Object> targetValues = facts_in_class.getTargetClasses();
+ Collection<?> targetValues = facts_in_class.getTargetClasses();
double prob, sum = 0;
for (Object key : targetValues) {
int num_in_class = facts_in_class.getVoteFor(key);
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -12,6 +12,8 @@
import dt.LeafNode;
import dt.TreeNode;
+import dt.memory.FactDistribution;
+import dt.memory.FactTargetDistribution;
import dt.memory.WorkingMemory;
import dt.memory.Fact;
import dt.memory.FactSet;
@@ -72,7 +74,7 @@
}
dt.FACTS_READ += facts.size();
- num_fact_processed = facts.size();
+ setNum_fact_processed(facts.size());
if (workingAttributes != null)
for (String attr: workingAttributes) {
@@ -112,7 +114,7 @@
}
}
dt.FACTS_READ += facts.size();
- num_fact_processed = facts.size();
+ setNum_fact_processed(facts.size());
if (workingAttributes != null)
for (String attr: workingAttributes) {
@@ -143,27 +145,15 @@
}
/* let's get the statistics of the results */
//List<?> targetValues = dt.getPossibleValues(dt.getTarget());
- Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget());//targetValues
- Collection<Object> targetValues = stats.keySet();
+ FactDistribution stats = new FactDistribution(dt.getDomain(dt.getTarget()));
+ stats.calculateDistribution(facts);
- int winner_vote = 0;
- int num_supporters = 0;
- Object winner = null;
- for (Object key: targetValues) {
+ stats.evaluateMajority();
- int num_in_class = stats.get(key).intValue();
- if (num_in_class>0)
- num_supporters ++;
- if (num_in_class > winner_vote) {
- winner_vote = num_in_class;
- winner = key;
- }
- }
-
/* if all elements are classified to the same value */
- if (num_supporters == 1) {
+ if (stats.getNum_supported_target_classes() == 1) {
//*OPT* return new LeafNode(facts.get(0).getFact(0).getFieldValue(target));
- LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+ LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
classifiedNode.setRank((double)facts.size()/(double)num_fact_processed);
classifiedNode.setNumSupporter(facts.size());
return classifiedNode;
@@ -172,9 +162,10 @@
/* if there is no attribute left in order to continue */
if (attributeNames.size() == 0) {
/* a heuristic of the leaf classification */
+ Object winner = stats.getThe_winner_target_class();
LeafNode noAttributeLeftNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
- noAttributeLeftNode.setRank((double)winner_vote/(double)num_fact_processed);
- noAttributeLeftNode.setNumSupporter(winner_vote);
+ noAttributeLeftNode.setRank((double)stats.getVoteFor(winner)/(double)num_fact_processed);
+ noAttributeLeftNode.setNumSupporter(stats.getVoteFor(winner));
return noAttributeLeftNode;
}
@@ -206,7 +197,7 @@
if (filtered_facts.get(value).isEmpty()) {
/* majority !!!! */
- LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+ LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
majorityNode.setRank(-1.0);
majorityNode.setNumSupporter(filtered_facts.get(value).size());
currentNode.addNode(value, majorityNode);
@@ -222,4 +213,12 @@
public int getNumCall() {
return FUNC_CALL;
}
+
+
+ public int getNum_fact_processed() {
+ return num_fact_processed;
+ }
+ public void setNum_fact_processed(int num_fact_processed) {
+ this.num_fact_processed = num_fact_processed;
+ }
}
Copied: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java (from rev 19371, labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java)
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -0,0 +1,86 @@
+package dt.memory;
+
+import java.util.Collection;
+import java.util.Hashtable;
+import java.util.List;
+
+import dt.tools.Util;
+
+public class FactAttrDistribution {
+ String attr_sum = Util.sum();
+
+ Hashtable<Object, FactTargetDistribution> facts_at_attr;
+
+ private int total_num;
+
+ public FactAttrDistribution(List<?> attributeValues, Domain<?> targetDomain) {
+ facts_at_attr = new Hashtable<Object, FactTargetDistribution>(attributeValues.size());
+
+ for (Object attr : attributeValues) {
+ facts_at_attr.put(attr, new FactTargetDistribution(targetDomain));
+// for (Object t : targetDomain.getValues()) {
+// facts_at_attr.get(attr).put(t, 0);
+// }
+// facts_at_attr.get(attr).put(attr_sum, 0);
+ }
+
+ }
+
+ public FactAttrDistribution clone() {
+ return this.clone();
+ }
+
+ public void setTotal(int size) {
+ this.total_num = size;
+ }
+ public int getTotal() {
+ return this.total_num;
+ }
+
+ public FactTargetDistribution getAttrFor(Object attr_value) {
+ return facts_at_attr.get(attr_value);
+ }
+
+ public int getSumForAttr(Object attr_value) {
+ return facts_at_attr.get(attr_value).getSum();
+ }
+
+ public void setSumForAttr(Object attr_value, int total) {
+ facts_at_attr.get(attr_value).setSum(total);
+ }
+
+// public void setTargetDistForAttr(Object attr_value, Hashtable<Object, Integer> targetDist) {
+// for (Object target: targetDist.keySet())
+// facts_at_attr.get(attr_value).put(target,targetDist.get(target));
+// }
+
+ public void setTargetDistForAttr(Object attr_value, FactTargetDistribution targetDist) {
+
+ //facts_at_attr.put(attr_value, targetDist);
+ /* TODO should i make a close */
+ FactTargetDistribution old = facts_at_attr.get(attr_value);
+ old.setDistribution(targetDist);
+
+ }
+
+ public void change(Object attrValue, Object targetValue, int i) {
+ facts_at_attr.get(attrValue).change(targetValue, i);
+
+ facts_at_attr.get(attrValue).change(attr_sum, i);
+/* int num_1 = facts_at_attr.get(attrValue).get(targetValue).intValue();
+ num_1 += i;
+ facts_at_attr.get(attrValue).put(targetValue, num_1);
+
+ int total_num_1 = facts_at_attr.get(attrValue).get(attr_sum).intValue();
+ total_num_1 += i;
+ facts_at_attr.get(attrValue).put(attr_sum, total_num_1);
+ */
+
+ }
+
+ public Collection<Object> getAttributes() {
+ return facts_at_attr.keySet();
+ }
+
+
+}
Deleted: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -1,78 +0,0 @@
-package dt.memory;
-
-import java.util.Collection;
-import java.util.Hashtable;
-import java.util.List;
-
-import dt.tools.Util;
-
-public class FactDistribution {
- String attr_sum = Util.sum();
-
- Hashtable<Object, Hashtable<Object, Integer>> facts_at_attr;
-
- private int total_num;
-
- public FactDistribution(List<?> attributeValues, List<?> targetValues) {
- facts_at_attr = new Hashtable<Object, Hashtable<Object, Integer>>(attributeValues.size());
-
- for (Object attr : attributeValues) {
- facts_at_attr.put(attr, new Hashtable<Object, Integer>(targetValues.size() + 1));
- for (Object t : targetValues) {
- facts_at_attr.get(attr).put(t, 0);
- }
- facts_at_attr.get(attr).put(attr_sum, 0);
- }
-
- }
-
- public FactDistribution clone() {
- return this.clone();
- }
-
- public void setTotal(int size) {
- this.total_num = size;
- }
- public int getTotal() {
- return this.total_num;
- }
-
- public Hashtable<Object, Integer> getAttrFor(Object attr_value) {
- return facts_at_attr.get(attr_value);
- }
-
- public int getSumForAttr(Object attr_value) {
- return facts_at_attr.get(attr_value).get(attr_sum).intValue();
- }
-
- public void setSumForAttr(Object attr_value, int total) {
- facts_at_attr.get(attr_value).put(attr_sum,total);
- }
-
- public void setTargetDistForAttr(Object attr_value, Hashtable<Object, Integer> targetDist) {
- for (Object target: targetDist.keySet())
- facts_at_attr.get(attr_value).put(target,targetDist.get(target));
- }
-
- public void setTargetDistForAttr(Object attr_value, FactTargetDistribution targetDist) {
- for (Object target: targetDist.getTargetClasses())
- facts_at_attr.get(attr_value).put(target,targetDist.getVoteFor(target));
- }
-
- public void change(Object attrValue, Object targetValue, int i) {
- int num_1 = facts_at_attr.get(attrValue).get(targetValue).intValue();
- num_1 += i;
- facts_at_attr.get(attrValue).put(targetValue, num_1);
-
- int total_num_1 = facts_at_attr.get(attrValue).get(attr_sum).intValue();
- total_num_1 += i;
- facts_at_attr.get(attrValue).put(attr_sum, total_num_1);
-
- }
-
- public Collection<Object> getAttributes() {
- return facts_at_attr.keySet();
- }
-
-
-}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -1,107 +1,65 @@
package dt.memory;
-import java.util.ArrayList;
import java.util.Collection;
import java.util.Hashtable;
import java.util.List;
import dt.tools.Util;
+/* simple histogram keeps the number of facts in each class of target */
public class FactTargetDistribution {
private String attr_sum = Util.sum();
- private Domain<?> targetDomain;
+ protected Domain<?> targetDomain;
private Hashtable<Object, Integer> num_at_target;
- private Hashtable<Object, List<Fact>> facts_at_target;
- private int num_supported_target_classes;
- private Object the_winner_target_class;
-
public FactTargetDistribution(Domain<?> targetDomain) {
-
-// this.targetDomain = targetDomain.clone();
-// targetDomain.
-
- num_supported_target_classes = 0;
+
this.targetDomain = targetDomain;
List<?> targetValues = targetDomain.getValues();
- num_at_target = new Hashtable<Object, Integer>(targetValues.size() + 1);
- facts_at_target = new Hashtable<Object, List<Fact>>(targetValues.size());
+ num_at_target = new Hashtable<Object, Integer>(targetValues.size() + 1);
for (Object t : targetValues) {
- num_at_target.put(t, 0);
- facts_at_target.put(t, new ArrayList<Fact>());
+ num_at_target.put(t, 0);
}
num_at_target.put(attr_sum, 0);
}
- public void calculateDistribution(List<Fact> facts){
- int total_num_facts = 0;
- String target = targetDomain.getName();
- for (Fact f : facts) {
- total_num_facts++;
- Object key = f.getFieldValue(target);
- // System.out.println("My key: "+ key.toString());
- num_at_target.put(key, num_at_target.get(key).intValue() + 1); // bocuk
- facts_at_target.get(key).add(f);
-
- }
- num_at_target.put(attr_sum, num_at_target.get(attr_sum).intValue() + total_num_facts);
-
+ public Collection<?> getTargetClasses() {
+ return targetDomain.getValues();
}
- public Collection<Object> getTargetClasses() {
- return facts_at_target.keySet();
- }
+
public int getSum() {
return num_at_target.get(attr_sum).intValue();
}
+ public void setSum(int sum) {
+ num_at_target.put(attr_sum, sum);
+ }
public int getVoteFor(Object value) {
return num_at_target.get(value).intValue();
}
- public List<Fact> getSupportersFor(Object value) {
- return facts_at_target.get(value);
+ public Domain<?> getTargetDomain() {
+ return targetDomain;
}
- public void evaluateMajority() {
-
- List<?> targetValues = targetDomain.getValues();
- int winner_vote = 0;
- int num_supporters = 0;
-
- Object winner = null;
- for (Object key : targetValues) {
- int num_in_class = num_at_target.get(key).intValue();
- if (num_in_class > 0)
- num_supporters++;
- if (num_in_class > winner_vote) {
- winner_vote = num_in_class;
- winner = key;
- }
- }
- setNum_supperted_target_classes(num_supporters);
- setThe_winner_target_class(winner);
-
+ public void setTargetDomain(Domain<?> targetDomain) {
+ this.targetDomain = targetDomain.clone();
}
-
- public int getNum_supported_target_classes() {
- return num_supported_target_classes;
+
+ public void change(Object targetValue, int i) {
+ int num_1 = num_at_target.get(targetValue).intValue();
+ num_1 += i;
+ num_at_target.put(targetValue, num_1);
}
- public void setNum_supperted_target_classes(int num_supperted_target_classes) {
- this.num_supported_target_classes = num_supperted_target_classes;
+ public void setDistribution(FactTargetDistribution targetDist) {
+ for (Object targetValue: targetDomain.getValues()) {
+ num_at_target.put(targetValue, targetDist.getVoteFor(targetValue));
+ }
+
}
-
- public Object getThe_winner_target_class() {
- return the_winner_target_class;
- }
-
- public void setThe_winner_target_class(Object the_winner_target_class) {
- this.the_winner_target_class = the_winner_target_class;
- }
-
-
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -8,6 +8,7 @@
import dt.memory.Domain;
import dt.memory.Fact;
+import dt.memory.FactDistribution;
import dt.memory.FactTargetDistribution;
public class FactProcessor {
@@ -101,7 +102,7 @@
}
public static void splitUnclassifiedFacts(
- List<Fact> unclassified_facts, FactTargetDistribution stats) {
+ List<Fact> unclassified_facts, FactDistribution stats) {
Object winner = stats.getThe_winner_target_class();
System.out.println(Util.ntimes("DANIEL", 2)+ " lets get unclassified daniel winner "+winner +" num of sup " +stats.getVoteFor(winner));
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -22,7 +22,7 @@
dt = System.currentTimeMillis() - dt;
System.out.println("Time" + dt + "\n" + bocuksTree);
- RulePrinter my_printer = new RulePrinter();
+ RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed());
boolean sort_via_rank = true;
my_printer.printer(bocuksTree, "examples", "src/rules/examples/"+drlfile, sort_via_rank);
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -39,14 +39,13 @@
this.num_facts = num_facts;
}
- public RulePrinter() {
- ruleText = new ArrayList<String>();
- //rule_list = new ArrayList<ArrayList<NodeValue>>();
- rules = new ArrayList<Rule>();
-
- /* most important */
- nodes = new Stack<NodeValue>();
+ public int getNum_facts() {
+ return num_facts;
}
+
+ public void setNum_facts(int num_facts) {
+ this.num_facts = num_facts;
+ }
public void printer(DecisionTree dt, String packageName, String outputFile, boolean sort) {//, PrintStream object
ruleObject = dt.getName();
@@ -68,11 +67,12 @@
Collections.sort(rules, Rule.getRankComparator());
int total_num_facts=0;
- int i = 0;
+ int i = 0, active_i = 0;
for( Rule rule: rules) {
i++;
if (ONLY_ACTIVE) {
if (rule.getRank() >= 0) {
+ active_i++;
System.out.println("//Active rules " +i + " write to drl \n"+ rule +"\n");
if (outputFile!=null) {
write(rule.toString(), true, outputFile);
@@ -81,16 +81,20 @@
}
} else {
+ if (rule.getRank() >= 0) {
+ active_i++;
+ }
System.out.println("//rule " +i + " write to drl \n"+ rule +"\n");
if (outputFile!=null) {
write(rule.toString(), true, outputFile);
write("\n", true, outputFile);
}
}
- total_num_facts += rule.getPopularity();
+ total_num_facts += rule.getPopularity();
}
if (outputFile!=null) {
- write("//THE END: Total number of facts correctly classified= "+ total_num_facts, true, outputFile);
+ write("//THE END: Total number of facts correctly classified= "+ total_num_facts + " over "+ getNum_facts() , true, outputFile);
+ write("\n//with " + active_i + " number of rules over "+i+" total number of rules ", true, outputFile);
write("\n", true, outputFile); // EOF
}
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -39,7 +39,7 @@
System.out.println("Time"+dt + " facts read: "+bocuksTree.getNumRead() + " num call: "+ bocuk.getNumCall() );
//System.out.println(bocuksTree);
- RulePrinter my_printer = new RulePrinter();
+ RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed());
boolean sort_via_rank = true;
my_printer.printer(bocuksTree, null, null, sort_via_rank);
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java 2008-04-01 23:53:50 UTC (rev 19371)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java 2008-04-02 01:29:16 UTC (rev 19372)
@@ -45,7 +45,7 @@
dt = System.currentTimeMillis() - dt;
System.out.println("Time"+dt+"\n"+bocuksTree);
- RulePrinter my_printer = new RulePrinter();
+ RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed());
boolean sort_via_rank = true;
my_printer.printer(bocuksTree,"test" , new String("../dt_learning/src/test/rules"+".drl"), sort_via_rank);
}
More information about the jboss-svn-commits
mailing list