[jboss-svn-commits] JBL Code SVN: r19333 - in labs/jbossrules/contrib/machinelearning/decisiontree/src/dt: memory and 1 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Mon Mar 31 02:17:11 EDT 2008
Author: gizil
Date: 2008-03-31 02:17:11 -0400 (Mon, 31 Mar 2008)
New Revision: 19333
Modified:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java
Log:
binary discretization
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-03-31 05:59:54 UTC (rev 19332)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-03-31 06:17:11 UTC (rev 19333)
@@ -188,13 +188,8 @@
TreeNode currentNode = new TreeNode(choosenDomain);
- Hashtable<Object, List<Fact>> filtered_facts = null;
+ Hashtable<Object, List<Fact>> filtered_facts = FactProcessor.splitFacts(facts, choosenDomain);
- if (choosenDomain.isDiscrete()) {
- filtered_facts = FactProcessor.splitFacts_disc(facts, choosenDomain.getName(), choosenDomain.getValues());
- } else {
- filtered_facts = FactProcessor.splitFacts_cont(facts, choosenDomain);
- }
dt.FACTS_READ += facts.size();
for (Object value : filtered_facts.keySet()) {
@@ -206,13 +201,11 @@
if (filtered_facts.get(value).isEmpty()) {
/* majority !!!! */
- LeafNode majorityNode = new LeafNode(dt.getDomain(dt
- .getTarget()), winner);
+ LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
majorityNode.setRank(0.0);
currentNode.addNode(value, majorityNode);
} else {
- TreeNode newNode = c45(dt, filtered_facts.get(value),
- attributeNames_copy);
+ TreeNode newNode = c45(dt, filtered_facts.get(value), attributeNames_copy);
currentNode.addNode(value, newNode);
}
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java 2008-03-31 05:59:54 UTC (rev 19332)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java 2008-03-31 06:17:11 UTC (rev 19333)
@@ -22,70 +22,40 @@
double dt_info = calc_info(facts_in_class, facts.size());
double greatestGain = -100000.0;
String attributeWithGreatestGain = attrs.get(0);
- Domain<?> domainWithGreatestGain = dt.getDomain(attributeWithGreatestGain);
-
+ Domain attrDomain = dt.getDomain(attributeWithGreatestGain);
Domain bestDomain = null;
- List<Object> bestValues = new ArrayList<Object>();
- String target = dt.getTarget();
- List<?> targetValues = dt.getPossibleValues(target);
+
+ Domain<?> targetDomain = dt.getDomain(dt.getTarget());
for (String attr : attrs) {
System.out.println("Which attribute to try: "+ attr);
double gain = 0;
- List<Fact> splitValues = null;
if (dt.getDomain(attr).isDiscrete()) {
- List<?> attributeValues = dt.getPossibleValues(attr);
- gain = dt_info - info_attr(facts, attr, attributeValues, target, targetValues);
+ attrDomain = dt.getDomain(attr).clone();
+ for (Object v: dt.getDomain(attr).getValues())
+ attrDomain.addValue(v);
+ gain = dt_info - info_attr(facts, attrDomain, targetDomain);
} else {
/* 1. sort the values */
Collections.sort(facts, facts.get(0).getDomain(attr).factComparator());
List<Fact> splits = getSplitPoints(facts, dt.getTarget());
- splitValues = new ArrayList<Fact>();
- splitValues.add(facts.get(facts.size()-1));
- System.out.println("Entropy.chooseContAttribute() hacking the representatives 1: "+ splitValues.size());
- for (Object v: splitValues) {
- System.out.println("Entropy.chooseContAttribute() splitValues:"+(Fact)v);
- }
- gain = dt_info - info_contattr(facts, attr, splitValues,
- target, targetValues,
- facts_in_class, splits);
- System.out.println("entropy.chooseContAttribute(1)*********** hey the new values to split: "+ splitValues.size());
+ attrDomain = dt.getDomain(attr).clone();
+ attrDomain.addPseudoValue(facts.get(facts.size()-1).getFieldValue(attr));
+ System.out.println("entropy.chooseContAttribute(1)*********** hey the new values to split: "+ attrDomain.getValues().get(0));
+
+ gain = dt_info - info_contattr(facts, attrDomain, targetDomain,
+ facts_in_class, splits);
+ System.out.println("entropy.chooseContAttribute(2)*********** hey the new values to split: "+ attrDomain.getValues().size());
}
- if (gain > greatestGain) {
-
- bestValues.clear();
+ if (gain > greatestGain) {
greatestGain = gain;
attributeWithGreatestGain = attr;
- domainWithGreatestGain = dt.getDomain(attr);
- if (domainWithGreatestGain.isDiscrete()) {
- for (Object value: domainWithGreatestGain.getValues())
- bestValues.add(value);
- } else {
- System.out.println("entropy.chooseContAttribute(2)*********** hey the new values to split: "+ splitValues.size());
-
- for (Fact f: splitValues)
- bestValues.add(f);
- }
+ bestDomain = attrDomain;
}
}
- bestDomain = domainWithGreatestGain.clone();
- if (bestDomain.isDiscrete()) {
- for (Object v: bestValues)
- bestDomain.addValue(v);
- } else {
- /* it is a hack fix it */
- System.out.println("entropy.chooseContAttribute(last)*********** hey the new values to split: "+ bestValues.size());
- for (Object v: bestValues) {
- System.out.println("Entropy.chooseContAttribute() fact:"+(Fact)v);
- ((NumericDomain)bestDomain).addRepresentative((Fact)v);
- }
- System.out.println("entropy.chooseContAttribute(after)*********** hey the new values to split: "+ ((NumericDomain)bestDomain).getRepresentatives().size());
-
- //Collections.sort(((NumericDomain)bestDomain).getRepresentatives(), bestDomain.factComparator());
- }
return bestDomain;
}
@@ -115,13 +85,14 @@
* instances of a single class or (b) some stopping criterion is reached. I
* can't remember what stopping criteria they used.
*/
-
- // *OPT* public double getGain(List<FactSet> facts, String attributeToSplit)
public static double info_contattr(List<Fact> facts,
- String splitAttr, List<Fact> splitValues,
- String targetAttr,List<?> targetValues,
+ Domain splitDomain, Domain<?> targetDomain,
Hashtable<Object, Integer> facts_in_class,
List<Fact> split_facts) {
+ String splitAttr = splitDomain.getName();
+ List<?> splitValues = splitDomain.getValues();
+ String targetAttr = targetDomain.getName();
+ List<?> targetValues = targetDomain.getValues();
System.out.println("What is the attributeToSplit? " + splitAttr);
@@ -135,7 +106,6 @@
.println("The size of the splits is 0 oups??? exiting....");
System.exit(0);
}
-
/* initialize the distribution */
Object key0 = Integer.valueOf(0);
@@ -150,42 +120,24 @@
facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
facts_at_attribute.setSumForAttr(key1, facts.size());
-// Hashtable<Object, Hashtable<Object, Integer>> facts_of_attribute =
-// new Hashtable<Object, Hashtable<Object, Integer>>(splitValues.size()+1);
-// // attr_0 bhas nothing everything inside attr_1
-//
-//
-// facts_of_attribute.put(key1,
-// new Hashtable<Object, Integer>(targetValues.size() + 1));
-// for (Object t : targetValues) {
-// facts_of_attribute.get(key1).put(t, facts_in_class.get(t));
-// }
-// facts_of_attribute.get(key1).put(attr_sum, facts.size());
-
- /*
- * 2. Look for potential cut-points.
- */
- double best_sum = 100000.0;
- Fact fact_to_split = splitValues.get(0);
+ double best_sum = +100000.0;
+ Object value_to_split = splitValues.get(0);
int split_index, index = 1;
Iterator<Fact> f_ite = facts.iterator();
Fact f1 = f_ite.next();
- while (f_ite.hasNext()) {
+ while (f_ite.hasNext()) {/* 2. Look for potential cut-points. */
Fact f2 = f_ite.next();
-
- // everytime it is not a split change the place in the distribution
-
Object targetKey = f2.getFieldValue(targetAttr);
// System.out.println("My key: "+ targetKey.toString());
//for (Object attr_key : attr_values)
+ /* every time it change the place in the distribution */
facts_at_attribute.change(key0, targetKey, +1);
facts_at_attribute.change(key1, targetKey, -1);
-
/*
* 2.1 Cut points are points in the sorted list above where the class labels change.
* Eg. if I had five instances with values for the attribute of interest and labels
@@ -209,18 +161,17 @@
if (sum < best_sum) {
best_sum = sum;
- fact_to_split = f2;
+ value_to_split = cut_point;
System.out.println("Entropy.info_contattr() hacking: "+ sum + " best sum "+best_sum +
- " new fact value "+ fact_to_split.getFieldValue(splitAttr));
+ " new split value "+ value_to_split);
split_index = index;
}
} else {}
f1 = f2;
index++;
}
+ splitDomain.addValue(value_to_split);
- splitValues.add(fact_to_split);
-
System.out.println("*********** hey the new values to split: "+ splitValues.size());
return best_sum;
}
@@ -229,7 +180,6 @@
/*
* id3 uses that function because it can not classify continuous attributes
*/
-
public static String chooseAttribute(DecisionTree dt, List<Fact> facts,
Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
@@ -237,16 +187,14 @@
double greatestGain = -1000;
String attributeWithGreatestGain = attrs.get(0);
String target = dt.getTarget();
- List<?> targetValues = dt.getPossibleValues(target);
+ Domain<?> targetDomain = dt.getDomain(target);
for (String attr : attrs) {
double gain = 0;
if (!dt.getDomain(attr).isDiscrete()) {
System.err.println("Ignoring the attribute:" +attr+ " the id3 can not classify continuous attributes");
continue;
} else {
- List<?> attributeValues = dt.getPossibleValues(attr);
-
- gain = dt_info - info_attr(facts, attr, attributeValues, target, targetValues);
+ gain = dt_info - info_attr(facts, dt.getDomain(attr), targetDomain);
}
System.out.println("Attribute: " + attr + " the gain: " + gain);
if (gain > greatestGain) {
@@ -260,38 +208,19 @@
return attributeWithGreatestGain;
}
-
-
-// public double gain(List<Fact> facts,
-// Hashtable<Object, Integer> facts_in_class, String attributeName) {
-// List<?> attributeValues = getPossibleValues(attributeName);
-// List<?> targetValues = getPossibleValues(getTarget());
-//
-// return Entropy.info(facts_in_class, facts.size())
-// - Entropy.info_attr(facts, attributeName, getTarget(), attributeValues, targetValues);
-// }
-
-
- // *OPT* public double getGain(List<FactSet> facts, String attributeToSplit)
- // {
public static double info_attr(List<Fact> facts,
- String attributeToSplit, List<?> attributeValues,
- String target, List<?> targetValues) {
+ Domain<?> splitDomain, Domain<?> targetDomain) {
+ String attributeToSplit = splitDomain.getName();
+ List<?> attributeValues = splitDomain.getValues();
+ String target = targetDomain.getName();
+ List<?> targetValues = targetDomain.getValues();
+
System.out.println("What is the attributeToSplit? " + attributeToSplit);
- //List<?> attributeValues = getPossibleValues(attributeToSplit);
- String attr_sum = Util.sum();
-
- //List<?> targetValues = getPossibleValues(getTarget());
- // Hashtable<Object, Integer> facts_in_class = new Hashtable<Object,
- // Integer>(targetValues.size());
-
/* initialize the hashtable */
FactDistribution facts_at_attribute = new FactDistribution(attributeValues, targetValues);
facts_at_attribute.setTotal(facts.size());
- // *OPT* for (FactSet fs: facts) {
- // *OPT* for (Fact f: fs.getFacts()) {
for (Fact f : facts) {
Object targetKey = f.getFieldValue(target);
// System.out.println("My key: "+ targetKey.toString());
@@ -331,24 +260,17 @@
*/
public static double calc_info(Hashtable<Object, Integer> facts_in_class,
int total_num_facts) {
- // List<?> targetValues = getPossibleValues(this.target);
- // Hashtable<Object, Integer> facts_in_class = getStatistics(facts,
- // getTarget()); //, targetValues);
Collection<Object> targetValues = facts_in_class.keySet();
double prob, sum = 0;
for (Object key : targetValues) {
int num_in_class = facts_in_class.get(key).intValue();
- // System.out.println("num_in_class : "+ num_in_class + " key "+ key
- // + " and the total num "+ total_num_facts);
+ // System.out.println("num_in_class : "+ num_in_class + " key "+ key+ " and the total num "+ total_num_facts);
if (num_in_class > 0) {
prob = (double) num_in_class / (double) total_num_facts;
/* TODO what if it is a sooo small number ???? */
- // double log2= Util.log2(prob);
- // double plog2p= prob*log2;
sum += -1 * prob * Util.log2(prob);
- // System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"
- // where the sum: "+sum);
+ // System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"where the sum: "+sum);
}
}
return sum;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java 2008-03-31 05:59:54 UTC (rev 19332)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java 2008-03-31 06:17:11 UTC (rev 19333)
@@ -187,8 +187,7 @@
/* the majority */
List<?> attributeValues = dt.getPossibleValues(chosenAttribute);
- Hashtable<Object, List<Fact> > filtered_facts =
- FactProcessor.splitFacts_disc(facts, chosenAttribute, attributeValues);
+ Hashtable<Object, List<Fact> > filtered_facts = FactProcessor.splitFacts_disc(facts, dt.getDomain(chosenAttribute));
dt.FACTS_READ += facts.size();
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java 2008-03-31 05:59:54 UTC (rev 19332)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java 2008-03-31 06:17:11 UTC (rev 19333)
@@ -40,9 +40,11 @@
}
public void addValue(Boolean value) {
- // TODO Auto-generated method stub
-
+ return;
}
+ public void addPseudoValue(Boolean value) {
+ return;
+ }
public List<Boolean> getValues() {
return fValues;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java 2008-03-31 05:59:54 UTC (rev 19332)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java 2008-03-31 06:17:11 UTC (rev 19333)
@@ -16,6 +16,7 @@
String getName();
void addValue(T value);
+ void addPseudoValue(T fieldValue);
List<T> getValues();
@@ -30,6 +31,7 @@
Comparator<Fact> factComparator();
public Domain<T> clone();
+
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java 2008-03-31 05:59:54 UTC (rev 19332)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java 2008-03-31 06:17:11 UTC (rev 19333)
@@ -51,14 +51,23 @@
public void addValue(String value) {
if (constant)
return;
- //if (discrete) {
- if (!fValues.contains(value))
+ if (discrete) {
+ if (!fValues.contains(value))
fValues.add(value);
-// } else {
-// fValues.add(value);
-// }
+ } else {
+ return;
+ }
}
+ public void addPseudoValue(String value) {
+ if (discrete) {
+ return;
+ } else {
+ if (!fValues.contains(value))
+ fValues.add(value);
+ }
+
+ }
public boolean contains(String value) {
for(String n: fValues) {
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java 2008-03-31 05:59:54 UTC (rev 19332)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java 2008-03-31 06:17:11 UTC (rev 19333)
@@ -54,10 +54,20 @@
if (!fValues.contains(value))
fValues.add(value);
} else {
-
+ return;
}
}
+
+ public void addPseudoValue(Number value) {
+ if (discrete) {
+ return;
+ } else {
+ if (!fValues.contains(value))
+ fValues.add(value);
+ }
+
+ }
public void addRepresentative(Fact f) {
if (!representatives.contains(f))
representatives.add(f);
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java 2008-03-31 05:59:54 UTC (rev 19332)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java 2008-03-31 06:17:11 UTC (rev 19333)
@@ -8,13 +8,21 @@
import dt.memory.Domain;
import dt.memory.Fact;
-import dt.memory.NumericDomain;
public class FactProcessor {
-
+ public static Hashtable<Object, List<Fact>> splitFacts(
+ List<Fact> facts, Domain<?> choosenDomain) {
+ if (choosenDomain.isDiscrete()) {
+ return FactProcessor.splitFacts_disc(facts, choosenDomain);
+ } else {
+ return FactProcessor.splitFacts_cont(facts, choosenDomain);
+ }
+ }
public static Hashtable<Object, List<Fact>> splitFacts_disc(
- List<Fact> facts, String attributeName, List<?> attributeValues) {
+ List<Fact> facts, Domain<?> choosenDomain) {
+ String attributeName = choosenDomain.getName();
+ List<?> attributeValues = choosenDomain.getValues();
Hashtable<Object, List<Fact>> factLists = new Hashtable<Object, List<Fact>>(attributeValues.size());
for (Object v : attributeValues) {
factLists.put(v, new ArrayList<Fact>());
@@ -26,33 +34,36 @@
}
/* it must work */
- public static Hashtable<Object, List<Fact>> splitFacts_cont(
+ private static Hashtable<Object, List<Fact>> splitFacts_cont(
List<Fact> facts, Domain<?> attributeDomain) {
+
String attributeName = attributeDomain.getName();
System.out.println("FactProcessor.splitFacts_cont() kimi diziyoruz: "+ attributeName);
- List<Fact> categorization = ((NumericDomain)attributeDomain).getRepresentatives();
+ List<?> categorization = attributeDomain.getValues();
System.out.println("FactProcessor.splitFacts_cont() haniymis benim repsentativelerim: "+ categorization.size());
- for (Fact f: categorization)
- System.out.println("FactProcessor.splitFacts_cont() haniymis benim factim: "+ f.getFieldValue(attributeName));
- Hashtable<Object, List<Fact>> factLists = new Hashtable<Object, List<Fact>>(
- categorization.size());
- for (Fact v : categorization) {
+ Hashtable<Object, List<Fact>> factLists = new Hashtable<Object, List<Fact>>(categorization.size());
+ for (Object v: attributeDomain.getValues()) {
factLists.put(v, new ArrayList<Fact>());
}
for (Fact f : facts) {
Comparator<Fact> cont_comp = attributeDomain.factComparator();
- ListIterator<Fact> category_it = categorization
- .listIterator(categorization.size() - 1);
+ ListIterator<?> category_it = attributeDomain.getValues().listIterator(attributeDomain.getValues().size() - 1);
while (category_it.hasPrevious()) {
- Fact category_fact = category_it.previous();
-
- if (cont_comp.compare(f, category_fact) < 0) {
- factLists.get(category_fact.getFieldValue(attributeName))
- .add(f);
+ Object category = category_it.previous();
+ Fact pseudo = new Fact();
+ try {
+ pseudo.add(attributeDomain, category);
+ if (cont_comp.compare(f, pseudo) < 0) {
+ factLists.get(category).add(f);
+ }
+ } catch (Exception e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
}
+
}
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java 2008-03-31 05:59:54 UTC (rev 19332)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java 2008-03-31 06:17:11 UTC (rev 19333)
@@ -110,7 +110,6 @@
Rule newRule = new Rule(nodes.size());// (nodes, leaves) //if more than one leaf
newRule.setObject(getRuleObject().toString());
Iterator<NodeValue> it = nodes.iterator();
-
while (it.hasNext()) {
NodeValue current = it.next();
@@ -320,10 +319,16 @@
this.nodeValue = nodeValue;
}
public String toString() {
+ String value;
if (node.getDomain() instanceof LiteralDomain)
- return node.getDomain() + " == "+ "\""+nodeValue+ "\"";
+ value = "\""+nodeValue+ "\"";
else
- return node.getDomain() + " == "+ nodeValue;
+ value = nodeValue + "";
+
+ if (node.getDomain().isDiscrete())
+ return node.getDomain() + " == "+ value;
+ else
+ return node.getDomain() + " <= "+ value;
}
}
More information about the jboss-svn-commits mailing list