[jboss-svn-commits] JBL Code SVN: r19330 - labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder.
jboss-svn-commits at lists.jboss.org
Mon Mar 31 00:10:48 EDT 2008
Author: gizil
Date: 2008-03-31 00:10:48 -0400 (Mon, 31 Mar 2008)
New Revision: 19330
Added:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/InformationMeasure.java
Log:
hackish working discretization
Added: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java 2008-03-31 04:10:48 UTC (rev 19330)
@@ -0,0 +1,373 @@
+package dt.builder;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Hashtable;
+import java.util.Iterator;
+import java.util.List;
+
+import dt.DecisionTree;
+import dt.memory.Domain;
+import dt.memory.Fact;
+import dt.memory.FactDistribution;
+import dt.tools.Util;
+import dt.memory.NumericDomain;
+
+public class Entropy implements InformationMeasure {
+
+ public static Domain<?> chooseContAttribute(DecisionTree dt, List<Fact> facts,
+ Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
+
+ double dt_info = calc_info(facts_in_class, facts.size());
+ double greatestGain = Double.NEGATIVE_INFINITY;
+ String attributeWithGreatestGain = attrs.get(0);
+ Domain<?> domainWithGreatestGain = dt.getDomain(attributeWithGreatestGain);
+
+ Domain bestDomain = null;
+ List<Object> bestValues = new ArrayList<Object>();
+ String target = dt.getTarget();
+ List<?> targetValues = dt.getPossibleValues(target);
+ for (String attr : attrs) {
+ System.out.println("Which attribute to try: "+ attr);
+ double gain = 0;
+ List<Fact> splitValues = null;
+ if (dt.getDomain(attr).isDiscrete()) {
+ List<?> attributeValues = dt.getPossibleValues(attr);
+ gain = dt_info - info_attr(facts, attr, attributeValues, target, targetValues);
+
+ } else {
+ /* 1. sort the values */
+ Collections.sort(facts, facts.get(0).getDomain(attr).factComparator());
+
+ List<Fact> splits = getSplitPoints(facts, dt.getTarget());
+ splitValues = new ArrayList<Fact>();
+ splitValues.add(facts.get(facts.size()-1));
+ System.out.println("Entropy.chooseContAttribute() hacking the representatives 1: "+ splitValues.size());
+ for (Object v: splitValues) {
+ System.out.println("Entropy.chooseContAttribute() splitValues:"+(Fact)v);
+ }
+ gain = dt_info - info_contattr(facts, attr, splitValues,
+ target, targetValues,
+ facts_in_class, splits);
+ System.out.println("entropy.chooseContAttribute(1)*********** hey the new values to split: "+ splitValues.size());
+
+ }
+
+ if (gain > greatestGain) {
+
+ bestValues.clear();
+ greatestGain = gain;
+ attributeWithGreatestGain = attr;
+ domainWithGreatestGain = dt.getDomain(attr);
+ if (domainWithGreatestGain.isDiscrete()) {
+ for (Object value: domainWithGreatestGain.getValues())
+ bestValues.add(value);
+ } else {
+ System.out.println("entropy.chooseContAttribute(2)*********** hey the new values to split: "+ splitValues.size());
+
+ for (Fact f: splitValues)
+ bestValues.add(f);
+ }
+ }
+ }
+ bestDomain = domainWithGreatestGain.clone();
+ if (bestDomain.isDiscrete()) {
+ for (Object v: bestValues)
+ bestDomain.addValue(v);
+ } else {
+ /* HACK: install the chosen split facts as representatives; needs a proper fix */
+ System.out.println("entropy.chooseContAttribute(last)*********** hey the new values to split: "+ bestValues.size());
+ for (Object v: bestValues) {
+ System.out.println("Entropy.chooseContAttribute() fact:"+(Fact)v);
+ ((NumericDomain)bestDomain).addRepresentative((Fact)v);
+ }
+ System.out.println("entropy.chooseContAttribute(after)*********** hey the new values to split: "+ ((NumericDomain)bestDomain).getRepresentatives().size());
+
+ //Collections.sort(((NumericDomain)bestDomain).getRepresentatives(), bestDomain.factComparator());
+ }
+
+ return bestDomain;
+ }
+
+ /*
+ * GLOBAL DISCRETIZATION example:
+ * (target) a a b a b b b b b
+ * (attr c) 1 2 3 4 5 6 7 8 9
+ * (binned) 0 0 0 0 1 1 1 1 1 -> "<5" / ">=5", i.e. "true" / "false"
+ */
+ /*
+ * The algorithm is basically (per attribute):
+ *
+ * 1. Sort the instances on the attribute of interest
+ *
+ * 2. Look for potential cut-points. Cut points are points in the sorted
+ * list above where the class labels change. E.g. if I had five instances
+ * with values for the attribute of interest and labels (1.0,A), (1.4,A),
+ * (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only two cutpoints
+ * of interest: 1.85 and 5 (mid-way between the points where the classes
+ * change from A to B or vice versa).
+ *
+ * 3. Evaluate your favourite disparity measure (info gain, gain ratio, gini
+ * coefficient, chi-squared test) on each of the cutpoints, and choose the
+ * one with the maximum value (I think Fayyad and Irani used info gain).
+ *
+ * 4. Repeat recursively in both subsets (the ones less than and greater
+ * than the cutpoint) until either (a) the subset is pure i.e. only contains
+ * instances of a single class or (b) some stopping criterion is reached. I
+ * can't remember what stopping criteria they used.
+ */
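+
+ /* A worked illustration of steps 2 and 3 above, using the same instances
+ * (the arithmetic is added here for exposition):
+ * cut at 1.85: left = {A,A,A} (H = 0), right = {B,B,A} (H ~ 0.918);
+ * weighted info = (3/6)*0 + (3/6)*0.918 ~ 0.459
+ * cut at 5: left = {A,A,A,B,B} (H ~ 0.971), right = {A} (H = 0);
+ * weighted info = (5/6)*0.971 + (1/6)*0 ~ 0.809
+ * The cut at 1.85 gives the lower weighted entropy, hence the higher
+ * information gain, so it is the better cut point.
+ */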
+
+ // *OPT* public double getGain(List<FactSet> facts, String attributeToSplit)
+ public static double info_contattr(List<Fact> facts,
+ String splitAttr, List<Fact> splitValues,
+ String targetAttr,List<?> targetValues,
+ Hashtable<Object, Integer> facts_in_class,
+ List<Fact> split_facts) {
+
+ System.out.println("What is the attributeToSplit? " + splitAttr);
+
+ if (facts.size() <= 1) {
+ System.err.println("Entropy.info_contattr(): need at least two facts, exiting...");
+ System.exit(1);
+ }
+ if (split_facts.size() < 1) {
+ System.err.println("Entropy.info_contattr(): the split list is empty, exiting...");
+ System.exit(1);
+ }
+
+
+ /* initialize the distribution with two buckets: key0 holds the facts
+ * below the candidate cut point, key1 those above it; every fact starts
+ * in key1 and migrates to key0 as the scan advances */
+ Object key0 = Integer.valueOf(0);
+ Object key1 = Integer.valueOf(1);
+ List<Object> keys = new ArrayList<Object>(2);
+ keys.add(key0);
+ keys.add(key1);
+
+
+ FactDistribution facts_at_attribute = new FactDistribution(keys, targetValues);
+ facts_at_attribute.setTotal(facts.size());
+ facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
+ facts_at_attribute.setSumForAttr(key1, facts.size());
+
+ /*
+ * 2. Look for potential cut-points.
+ */
+ double best_sum = Double.MAX_VALUE;
+ Fact fact_to_split = splitValues.get(0);
+ int split_index, index = 1;
+
+ Iterator<Fact> f_ite = facts.iterator();
+ Fact f1 = f_ite.next();
+ while (f_ite.hasNext()) {
+
+ Fact f2 = f_ite.next();
+
+ // as the scan passes each fact, move it from the upper bucket (key1) to the lower one (key0)
+
+ Object targetKey = f2.getFieldValue(targetAttr);
+
+ // System.out.println("My key: "+ targetKey.toString());
+ //for (Object attr_key : attr_values)
+
+ facts_at_attribute.change(key0, targetKey, +1);
+ facts_at_attribute.change(key1, targetKey, -1);
+
+
+ /*
+ * 2.1 Cut points are points in the sorted list above where the class labels change.
+ * Eg. if I had five instances with values for the attribute of interest and labels
+ * (1.0,A), (1.4,A), (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only
+ * two cutpoints of interest: 1.85 and 5 (mid-way between the points
+ * where the classes change from A to B or vice versa).
+ */
+ if (!f1.getFieldValue(targetAttr).equals(f2.getFieldValue(targetAttr))) { // value equality, not ==
+ // the cut point
+ Number cp_i = (Number) f1.getFieldValue(splitAttr);
+ Number cp_i_next = (Number) f2.getFieldValue(splitAttr);
+
+ Number cut_point = (cp_i.doubleValue() + cp_i_next.doubleValue()) / 2;
+
+ /*
+ * 3. Evaluate your favourite disparity measure
+ * (info gain, gain ratio, gini coefficient, chi-squared test) on the cut point
+ * and calculate its gain
+ */
+ double sum = calc_info_attr(facts_at_attribute);
+
+ if (sum < best_sum) {
+ best_sum = sum;
+ fact_to_split = f2;
+ System.out.println("Entropy.info_contattr() hacking: "+ sum + " best sum "+best_sum +
+ " new fact value "+ fact_to_split.getFieldValue(splitAttr));
+ split_index = index;
+ }
+ }
+ f1 = f2;
+ index++;
+ }
+
+ splitValues.add(fact_to_split);
+
+ System.out.println("*********** hey the new values to split: "+ splitValues.size());
+ return best_sum;
+ }
+
+
+ /*
+ * ID3 uses this method because it cannot classify continuous attributes.
+ */
+
+ public static String chooseAttribute(DecisionTree dt, List<Fact> facts,
+ Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
+
+ double dt_info = calc_info(facts_in_class, facts.size());
+ double greatestGain = Double.NEGATIVE_INFINITY;
+ String attributeWithGreatestGain = attrs.get(0);
+ String target = dt.getTarget();
+ List<?> targetValues = dt.getPossibleValues(target);
+ for (String attr : attrs) {
+ double gain = 0;
+ if (!dt.getDomain(attr).isDiscrete()) {
+ System.err.println("Ignoring the attribute:" +attr+ " the id3 can not classify continuous attributes");
+ continue;
+ } else {
+ List<?> attributeValues = dt.getPossibleValues(attr);
+
+ gain = dt_info - info_attr(facts, attr, attributeValues, target, targetValues);
+ }
+ System.out.println("Attribute: " + attr + " the gain: " + gain);
+ if (gain > greatestGain) {
+ greatestGain = gain;
+ attributeWithGreatestGain = attr;
+ }
+
+
+ }
+
+ return attributeWithGreatestGain;
+ }
+
+
+
+// public double gain(List<Fact> facts,
+// Hashtable<Object, Integer> facts_in_class, String attributeName) {
+// List<?> attributeValues = getPossibleValues(attributeName);
+// List<?> targetValues = getPossibleValues(getTarget());
+//
+// return Entropy.info(facts_in_class, facts.size())
+// - Entropy.info_attr(facts, attributeName, getTarget(), attributeValues, targetValues);
+// }
+
+
+ // *OPT* public double getGain(List<FactSet> facts, String attributeToSplit)
+ // {
+ public static double info_attr(List<Fact> facts,
+ String attributeToSplit, List<?> attributeValues,
+ String target, List<?> targetValues) {
+ System.out.println("What is the attributeToSplit? " + attributeToSplit);
+ //List<?> attributeValues = getPossibleValues(attributeToSplit);
+
+ String attr_sum = Util.sum();
+
+ //List<?> targetValues = getPossibleValues(getTarget());
+ // Hashtable<Object, Integer> facts_in_class = new Hashtable<Object,
+ // Integer>(targetValues.size());
+
+ /* initialize the distribution */
+ FactDistribution facts_at_attribute = new FactDistribution(attributeValues, targetValues);
+ facts_at_attribute.setTotal(facts.size());
+
+ // *OPT* for (FactSet fs: facts) {
+ // *OPT* for (Fact f: fs.getFacts()) {
+ for (Fact f : facts) {
+ Object targetKey = f.getFieldValue(target);
+ // System.out.println("My key: "+ targetKey.toString());
+
+ Object attr_key = f.getFieldValue(attributeToSplit);
+ facts_at_attribute.change(attr_key, targetKey, +1);
+
+ // System.out.println("getGain of "+attributeToSplit+
+ // ": total_num "+ facts_of_attribute.get(attr_key).get(attr_sum) +
+ // " and "+facts_of_attribute.get(attr_key).get(targetKey) +
+ // " at attr=" + attr_key + " of t:"+targetKey);
+ }
+ double sum = calc_info_attr(facts_at_attribute);
+ return sum;
+ }
+
+ /*
+ * Weighted average entropy over the attribute's values, shared by the
+ * discrete (info_attr) and continuous (info_contattr) paths:
+ * Info_A(S) = sum over v of (|S_v| / |S|) * H(S_v)
+ */
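+ /* Worked illustration (numbers added for exposition): splitting 14 facts
+ * into v1 = {2 a, 3 b} and v2 = {6 a, 3 b} gives
+ * Info = (5/14)*H(2/5,3/5) + (9/14)*H(6/9,3/9)
+ * = (5/14)*0.971 + (9/14)*0.918 ~ 0.937 bits
+ */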
+ private static double calc_info_attr( FactDistribution facts_of_attribute) {
+ Collection<Object> attributeValues = facts_of_attribute.getAttributes();
+ int fact_size = facts_of_attribute.getTotal();
+ double sum = 0.0;
+ for (Object attr : attributeValues) {
+ int total_num_attr = facts_of_attribute.getSumForAttr(attr);
+ //double sum_attr = 0.0;
+ if (total_num_attr > 0) {
+ sum += ((double) total_num_attr / (double) fact_size) *
+ calc_info(facts_of_attribute.getAttrFor(attr), total_num_attr);
+ }
+ }
+ return sum;
+ }
+
+ /*
+ * Shannon entropy of a class distribution:
+ * H(S) = - sum over classes i of p_i * log2(p_i)
+ */
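+ /* Worked illustration (numbers added for exposition): for 14 facts with
+ * class counts {a: 9, b: 5},
+ * H = -(9/14)*log2(9/14) - (5/14)*log2(5/14) ~ 0.940 bits
+ */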
+ public static double calc_info(Hashtable<Object, Integer> facts_in_class,
+ int total_num_facts) {
+ // List<?> targetValues = getPossibleValues(this.target);
+ // Hashtable<Object, Integer> facts_in_class = getStatistics(facts,
+ // getTarget()); //, targetValues);
+ Collection<Object> targetValues = facts_in_class.keySet();
+ double prob, sum = 0;
+ for (Object key : targetValues) {
+ int num_in_class = facts_in_class.get(key).intValue();
+ // System.out.println("num_in_class : "+ num_in_class + " key "+ key
+ // + " and the total num "+ total_num_facts);
+
+ if (num_in_class > 0) {
+ prob = (double) num_in_class / (double) total_num_facts;
+ /* TODO: decide how to handle very small probabilities */
+ // double log2= Util.log2(prob);
+ // double plog2p= prob*log2;
+ sum += -1 * prob * Util.log2(prob);
+ // System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"
+ // where the sum: "+sum);
+ }
+ }
+ return sum;
+ }
+
+ private static List<Fact> getSplitPoints(List<Fact> facts, String target) {
+ List<Fact> splits = new ArrayList<Fact>();
+ Iterator<Fact> it_f = facts.iterator();
+ Fact f1 = it_f.next();
+ int index = 0;
+ while(it_f.hasNext()){
+ Fact f2 = it_f.next();
+ if (!f1.getFieldValue(target).equals(f2.getFieldValue(target))) // value equality, not ==
+ splits.add(f2);
+
+ f1= f2;
+ index++;
+ }
+ return splits;
+ }
+
+}
Added: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/InformationMeasure.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/InformationMeasure.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/InformationMeasure.java 2008-03-31 04:10:48 UTC (rev 19330)
@@ -0,0 +1,5 @@
+package dt.builder;
+
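+/**
+ * Common type for split-quality measures such as Entropy. No methods are
+ * declared in this revision, so for now it serves only as a marker
+ * interface.
+ */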
+public interface InformationMeasure {
+
+}