[jboss-svn-commits] JBL Code SVN: r19315 - labs/jbossrules/contrib/machinelearning/decisiontree/src/id3.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Sat Mar 29 21:10:03 EDT 2008
Author: gizil
Date: 2008-03-29 21:10:03 -0400 (Sat, 29 Mar 2008)
New Revision: 19315
Added:
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DomainSpec.java
Removed:
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/ReadingSeq.java
Modified:
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/BocukFileExample.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/BooleanDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTree.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTreeBuilderMT.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Domain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Fact.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/FactSetFactory.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/LiteralDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/NumericDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/ObjectReader.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Util.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/WorkingMemory.java
Log:
c4.5 first steps: discretizing the continuous attrs
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/BocukFileExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/BocukFileExample.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/BocukFileExample.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -1,5 +1,7 @@
package id3;
+import java.util.List;
+
public class BocukFileExample {
public static void main(String[] args) {
@@ -37,26 +39,29 @@
}
}
- public static void processFileExample(Object emptyObject, String drlfile, String datafile, String separator, String target) {
+ public static List<Object> processFileExample(WorkingMemory simple, Object emptyObject, String drlfile, String datafile, String separator) {
- WorkingMemory simple = new WorkingMemory();
-
try {
- FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator);
+ List<Object> obj_read=FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator);
DecisionTreeBuilder bocuk = new DecisionTreeBuilder();
long dt = System.currentTimeMillis();
- DecisionTree bocuksTree = bocuk.build(simple, emptyObject.getClass().getName(), target, null);
+ String target_attr = Util.getTargetAnnotation(emptyObject.getClass());
+
+ DecisionTree bocuksTree = bocuk.build(simple, emptyObject.getClass().getName(), target_attr, null);
dt = System.currentTimeMillis() - dt;
System.out.println("Time" + dt + "\n" + bocuksTree);
RulePrinter my_printer = new RulePrinter();
my_printer.printer(bocuksTree, "examples", "src/rules/examples/"+drlfile);
+ return obj_read;
+
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
+ return null;
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/BooleanDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/BooleanDomain.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/BooleanDomain.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -8,6 +8,7 @@
private String fName;
private ArrayList<Boolean> fValues;
private boolean constant;
+ private int readingSeq;
public BooleanDomain(String _name) {
@@ -79,4 +80,19 @@
return out;
}
+ public void setReadingSeq(int readingSeq) { // stores the given value in this domain's readingSeq field (presumably the order in which the attribute is read from the input -- TODO confirm)
+ this.readingSeq = readingSeq;
+
+ }
+
+ public int getReadingSeq() { // returns the current value of this domain's readingSeq field
+ return this.readingSeq;
+
+ }
+
+ public void setDiscrete(boolean disc) {
+ // No-op stub: the disc argument is ignored and no state is changed (auto-generated placeholder, still TODO).
+
+ }
+
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTree.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTree.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -1,25 +1,27 @@
package id3;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Hashtable;
+import java.util.Iterator;
import java.util.List;
+import java.util.Set;
public class DecisionTree {
-
+
public long FACTS_READ = 0;
-
- /* set of the attributes, their types*/
- private Hashtable<String, Domain<?>> domainSet;
+ /* set of the attributes, their types */
+ private Hashtable<String, Domain<?>> domainSet;
+
/* the class of the objects */
private String className;
/* the target attribute */
private String target;
-
private TreeNode root;
/* all attributes that can be used during classification */
@@ -30,15 +32,15 @@
this.domainSet = new Hashtable<String, Domain<?>>();
this.attrsToClassify = new ArrayList<String>();
}
-
-
+
private Object getConsensus(List<Fact> facts) {
- List<?> targetValues = getPossibleValues(this.target);
- Hashtable<Object, Integer> facts_in_class = getStatistics(facts, target, targetValues);
-
+ List<?> targetValues = getPossibleValues(this.target);
+ Hashtable<Object, Integer> facts_in_class = getStatistics(facts, target);
+ // , targetValues
+
int winner_vote = 0;
Object winner = null;
- for (Object key: targetValues) {
+ for (Object key : targetValues) {
int num_in_class = facts_in_class.get(key).intValue();
if (num_in_class > winner_vote) {
@@ -49,184 +51,495 @@
return winner;
}
+ // *OPT* public double calculateGain(List<FactSet> facts, String
+ // attributeName) {
+ // I don't use
+ public double calculateGain(List<Fact> facts,
+ Hashtable<Object, Integer> facts_in_class, String attributeName) {
-//*OPT* public double calculateGain(List<FactSet> facts, String attributeName) {
- public double calculateGain(List<Fact> facts, String attributeName) {
- return getInformation(facts) - getGain(facts, attributeName);
+ return getInformation(facts_in_class, facts.size())
+ - getGain(facts, attributeName);
}
-//*OPT* public double getGain(List<FactSet> facts, String attributeToSplit) {
+ // *OPT* public double getGain(List<FactSet> facts, String attributeToSplit)
+ // {
public double getGain(List<Fact> facts, String attributeToSplit) {
- System.out.println("What is the attributeToSplit? "+attributeToSplit);
+ System.out.println("What is the attributeToSplit? " + attributeToSplit);
List<?> attributeValues = getPossibleValues(attributeToSplit);
String attr_sum = "sum";
- List<?> targetValues = getPossibleValues(target);
- //Hashtable<Object, Integer> facts_in_class = new Hashtable<Object, Integer>(targetValues.size());
+ List<?> targetValues = getPossibleValues(getTarget());
+ // Hashtable<Object, Integer> facts_in_class = new Hashtable<Object,
+ // Integer>(targetValues.size());
/* initialize the hashtable */
- Hashtable<Object, Hashtable<Object, Integer>> facts_of_attribute = new Hashtable<Object, Hashtable<Object, Integer>>(attributeValues.size());
- for (Object attr: attributeValues) {
- facts_of_attribute.put(attr, new Hashtable<Object, Integer>(targetValues.size()+1));
- for (Object t: targetValues) {
+ Hashtable<Object, Hashtable<Object, Integer>> facts_of_attribute = new Hashtable<Object, Hashtable<Object, Integer>>(
+ attributeValues.size());
+ for (Object attr : attributeValues) {
+ facts_of_attribute.put(attr, new Hashtable<Object, Integer>(
+ targetValues.size() + 1));
+ for (Object t : targetValues) {
facts_of_attribute.get(attr).put(t, 0);
}
facts_of_attribute.get(attr).put(attr_sum, 0);
}
+ int total_num_facts = 0;
+ // *OPT* for (FactSet fs: facts) {
+ // *OPT* for (Fact f: fs.getFacts()) {
+ for (Fact f : facts) {
+ total_num_facts++;
+ Object targetKey = f.getFieldValue(target);
+ // System.out.println("My key: "+ targetKey.toString());
- int total_num_facts= 0;
-//*OPT* for (FactSet fs: facts) {
-//*OPT* for (Fact f: fs.getFacts()) {
- for (Fact f: facts) {
- total_num_facts ++;
- Object targetKey = f.getFieldValue(target);
- //System.out.println("My key: "+ targetKey.toString());
+ Object attr_key = f.getFieldValue(attributeToSplit);
+ int num = facts_of_attribute.get(attr_key).get(targetKey)
+ .intValue();
+ num++;
+ facts_of_attribute.get(attr_key).put(targetKey, num);
- Object attr_key = f.getFieldValue(attributeToSplit);
- int num = facts_of_attribute.get(attr_key).get(targetKey).intValue();
- num ++;
- facts_of_attribute.get(attr_key).put(targetKey, num);
+ int total_num = facts_of_attribute.get(attr_key).get(attr_sum)
+ .intValue();
+ total_num++;
+ facts_of_attribute.get(attr_key).put(attr_sum, total_num);
- int total_num = facts_of_attribute.get(attr_key).get(attr_sum).intValue();
- total_num ++;
- facts_of_attribute.get(attr_key).put(attr_sum, total_num);
-
-// System.out.println("getGain of "+attributeToSplit+
-// ": total_num "+ facts_of_attribute.get(attr_key).get(attr_sum) +
-// " and "+facts_of_attribute.get(attr_key).get(targetKey) +
-// " at attr=" + attr_key + " of t:"+targetKey);
+ // System.out.println("getGain of "+attributeToSplit+
+ // ": total_num "+ facts_of_attribute.get(attr_key).get(attr_sum) +
+ // " and "+facts_of_attribute.get(attr_key).get(targetKey) +
+ // " at attr=" + attr_key + " of t:"+targetKey);
}
FACTS_READ += facts.size();
-//*OPT* }
-//*OPT* }
+ // *OPT* }
+ // *OPT* }
+ double sum = getAttrInformation(facts_of_attribute, total_num_facts);
+// for (Object attr : attributeValues) {
+// int total_num_attr = facts_of_attribute.get(attr).get(attr_sum)
+// .intValue();
+//
+// double sum_attr = 0.0;
+// if (total_num_attr > 0)
+// for (Object t : targetValues) {
+// int num_attr_target = facts_of_attribute.get(attr).get(t)
+// .intValue();
+//
+// double prob = (double) num_attr_target / total_num_attr;
+// // System.out.println("prob "+ prob);
+// sum_attr += (prob == 0.0) ? 0.0 : (-1 * prob * Util
+// .log2(prob));
+// }
+// sum += ((double) total_num_attr / (double) total_num_facts)
+// * sum_attr;
+// }
+ return sum;
+ }
- double sum = 0.0;
- for (Object attr: attributeValues) {
- int total_num_attr = facts_of_attribute.get(attr).get(attr_sum).intValue();
-
- double sum_attr = 0.0;
- if (total_num_attr > 0)
- for (Object t: targetValues) {
- int num_attr_target = facts_of_attribute.get(attr).get(t).intValue();
+ /*
+ * GLOBAL DISCRETIZATION -- target: a a b a b b b b b; attr c: 1 2 3 4 5 6 7 8 9;
+ * discretized c: 0 0 0 0 1 1 1 1 1, labelled "<5", ">=5" / "true" "false"
+ */
+ /*
+ * The algorithm is basically (per attribute):
+ *
+ * 1. Sort the instances on the attribute of interest
+ *
+ * 2. Look for potential cut-points. Cut points are points in the sorted
+ * list above where the class labels change. Eg. if I had five instances
+ * with values for the attribute of interest and labels (1.0,A), (1.4,A),
+ * (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only two cutpoints
+ * of interest: 1.85 and 5 (mid-way between the points where the classes
+ * change from A to B or vice versa).
+ *
+ * 3. Evaluate your favourite disparity measure (info gain, gain ratio, gini
+ * coefficient, chi-squared test) on each of the cutpoints, and choose the
+ * one with the maximum value (I think Fayyad and Irani used info gain).
+ *
+ * 4. Repeat recursively in both subsets (the ones less than and greater
+ * than the cutpoint) until either (a) the subset is pure i.e. only contains
+ * instances of a single class or (b) some stopping criterion is reached. I
+ * can't remember what stopping criteria they used.
+ */
- double prob = (double)num_attr_target/total_num_attr;
- //System.out.println("prob "+ prob);
- sum_attr += (prob == 0.0) ? 0.0 : (-1* prob * Util.log2(prob));
+ // *OPT* public double getGain(List<FactSet> facts, String attributeToSplit)
+ public double getContinuousGain(List<Fact> facts,
+ List<Integer> split_facts, int begin_index, int end_index,
+ Hashtable<Object, Integer> facts_in_class, String attributeToSplit) {
+
+ System.out.println("What is the attributeToSplit? " + attributeToSplit);
+
+ if (facts.size() <= 1) {
+ System.out
+ .println("The size of the fact list is 0 oups??? exiting....");
+ System.exit(0);
+ }
+ if (split_facts.size() < 1) {
+ System.out
+ .println("The size of the splits is 0 oups??? exiting....");
+ System.exit(0);
+ }
+
+ String targetAttr = getTarget();
+ List<?> targetValues = getPossibleValues(getTarget());
+ List<?> boundaries = getPossibleValues(attributeToSplit);
+
+ // Fact split_point = facts.get(facts.size() / 2);
+ // a b a a b
+ // 1 2 3 4 5
+ // 1.5
+ // 2.5
+ // 3.5
+ // 0.00001 0.00002 1 100
+ // 0.000015
+
+ // < 50 >
+ // 25 75
+ // HashTable<Boolean>
+
+ String attr_sum = Util.getSum();
+
+
+
+ /* initialize the hashtable */
+ Hashtable<Object, Hashtable<Object, Integer>> facts_of_attribute =
+ new Hashtable<Object, Hashtable<Object, Integer>>(Util.getDividingSize());
+ // attr_0 (bucket 0) starts with nothing; everything starts inside attr_1 (bucket 1)
+ Object cut_point; //attr_0
+ Object last_poit = facts.get(facts.size()-1).getFieldValue(attributeToSplit);
+ for (int i = 0; i < 2; i++) {
+ facts_of_attribute.put(Integer.valueOf(i),
+ new Hashtable<Object, Integer>(targetValues.size() + 1));
+ //Hashtable<Object, Integer> facts_in_class
+ if (i == 1) {
+ for (Object t : targetValues) {
+ facts_of_attribute.get(Integer.valueOf(i)).put(t,
+ facts_in_class.get(t));
}
- sum += ((double)total_num_attr/(double)total_num_facts) * sum_attr;
+ facts_of_attribute.get(Integer.valueOf(i)).put(attr_sum,
+ facts.size());
+ } else {
+ for (Object t : targetValues) {
+ facts_of_attribute.get(Integer.valueOf(i)).put(t, 0);
+ }
+ facts_of_attribute.get(Integer.valueOf(i)).put(attr_sum, 0);
+ }
}
- return sum;
- }
+
+ /*
+ * 2. Look for potential cut-points.
+ */
+
+ int split_index = 1;
+ int last_index = facts.size();
+ Iterator<Fact> f_ite = facts.iterator();
+ Fact f1 = f_ite.next();
+ while (f_ite.hasNext()) {
+
+ Fact f2 = f_ite.next();
+
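+ // on every iteration (whether or not f1/f2 form a split) move f2 from bucket 1 to bucket 0 in the distribution; NOTE(review): bucket 1's sum is incremented below (total_num_2++) although its class count is decremented -- presumably a decrement was intended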
+
+ Object targetKey = f2.getFieldValue(target);
+
+ // System.out.println("My key: "+ targetKey.toString());
+
+ //for (Object attr_key : attr_values)
+
+ Object attr_key_1 = Integer.valueOf(0);
+ int num_1 = facts_of_attribute.get(attr_key_1).get(targetKey).intValue();
+ num_1++;
+ facts_of_attribute.get(attr_key_1).put(targetKey, num_1);
+
+ int total_num_1 = facts_of_attribute.get(attr_key_1).get(attr_sum).intValue();
+ total_num_1++;
+ facts_of_attribute.get(attr_key_1).put(attr_sum, total_num_1);
+
+ Object attr_key_2= Integer.valueOf(1);
+ int num_2 = facts_of_attribute.get(attr_key_2).get(targetKey).intValue();
+ num_2--;
+ facts_of_attribute.get(attr_key_2).put(targetKey, num_2);
+
+ int total_num_2 = facts_of_attribute.get(attr_key_2).get(attr_sum).intValue();
+ total_num_2++;
+ facts_of_attribute.get(attr_key_2).put(attr_sum, total_num_2);
+
+ /*
+ * 2.1 Cut points are points in the sorted list above where the class labels change.
+ * Eg. if I had five instances with values for the attribute of interest and labels
+ * (1.0,A), (1.4,A), (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only
+ * two cutpoints of interest: 1.85 and 5 (mid-way between the points
+ * where the classes change from A to B or vice versa).
+ */
+ if (f1.getFieldValue(targetAttr) != f2.getFieldValue(targetAttr)) {
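+ // the cut point is the midpoint of the two neighbouring attribute values; NOTE(review): the enclosing check compares the target values with != (a reference comparison if getFieldValue returns Object) -- equals() is presumably intended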
+ Number cp_i = (Number) f1.getFieldValue(attributeToSplit);
+ Number cp_i_next = (Number) f2.getFieldValue(attributeToSplit);
+
+ cut_point = (cp_i.doubleValue() + cp_i_next
+ .doubleValue()) / 2;
+
+ /*
+ * 3. Evaluate your favourite disparity measure
+ * (info gain, gain ratio, gini coefficient, chi-squared test) on the cut point
+ * and calculate its gain
+ */
+ double sum = getAttrInformation(facts_of_attribute, facts.size());
+//
+// double sum = 0.0;
+// // for (Object attr : attributeValues) {
+// for (int i = 0; i < 2; i++) {
+//
+// int total_num_attr = facts_of_attribute.get(Integer.valueOf(i)).get(attr_sum).intValue();
+//
+// double sum_attr = 0.0;
+// if (total_num_attr > 0)
+// for (Object t : targetValues) {
+// int num_attr_target = facts_of_attribute.get(Integer.valueOf(i)).get(t).intValue();
+//
+// double prob = (double) num_attr_target / total_num_attr;
+// // System.out.println("prob "+ prob);
+// sum_attr += (prob == 0.0) ? 0.0 : (-1 * prob * Util.log2(prob));
+// }
+// sum += ((double) total_num_attr / (double) facts.size())* sum_attr;
+// }
-
- private class FactNumericAttributeComparator implements Comparator<Fact> {
- private String attr_name;
- public FactNumericAttributeComparator(String _attr_name) {
- attr_name = _attr_name;
+
+ } else {}
+
+// getContinuousGain(facts, split_facts.subList(0,
+// split_index+1), 0, split_index+1,
+// facts_in_class1, attributeToSplit);
+//
+// getContinuousGain(facts, split_facts.subList(split_index+1,
+// last_index), split_index+1, last_index,
+// facts_in_class2, attributeToSplit);
+
+ f1 = f2;
+ split_index ++;
}
- public int compare(Fact f0, Fact f1) {
- Number n0 = (Number)f0.getFieldValue(attr_name);
- Number n1 = (Number)f1.getFieldValue(attr_name);
- if (n0.doubleValue() < n1.doubleValue())
- return -1;
- else if (n0.doubleValue() > n1.doubleValue())
- return 1;
- else
- return 0;
- }
+ return 1.0;
}
- /* GLOBAL DISCRETIZATION
- a a b a b b b b b (target)
- 1 2 3 4 5 6 7 8 9 (attr c)
- 0 0 0 0 1 1 1 1 1
- "<5", ">=5"
- "true" "false"
- */
+ public double getContinuousGain_(List<Fact> facts,
+ List<Integer> split_facts, int begin_index, int end_index,
+ Hashtable<Object, Integer> facts_in_class, String attributeToSplit) {
+ System.out.println("What is the attributeToSplit? " + attributeToSplit);
- //*OPT* public double getGain(List<FactSet> facts, String attributeToSplit) {
- public double getContinuousGain(List<Fact> facts, String attributeToSplit) {
- System.out.println("What is the attributeToSplit? "+attributeToSplit);
+ if (facts.size() <= 1) {
+ System.out
+ .println("The size of the fact list is 0 oups??? exiting....");
+ System.exit(0);
+ }
+ if (split_facts.size() < 1) {
+ System.out
+ .println("The size of the splits is 0 oups??? exiting....");
+ System.exit(0);
+ }
+ String targetAttr = getTarget();
List<?> boundaries = getPossibleValues(attributeToSplit);
- /* sort the values */
- Collections.sort(facts, new FactNumericAttributeComparator(attributeToSplit));
+ // Fact split_point = facts.get(facts.size() / 2);
+ // a b a a b
+ // 1 2 3 4 5
+ // 1.5
+ // 2.5
+ // 3.5
+ // 0.00001 0.00002 1 100
+ // 0.000015
- //Fact split_point = facts.get(facts.size() / 2);
- // a b a a b
- // 1 2 3 4 5
- // 1.5
- // 2.5
- // 3.5
- // 0.00001 0.00002 1 100
- // 0.000015
-
- // < 50 >
- // 25 75
- //HashTable<Boolean>
-
+ // < 50 >
+ // 25 75
+ // HashTable<Boolean>
+
String attr_sum = "sum";
+ /*
+ * 2. Look for potential cut-points. Cut points are points in the sorted
+ * list above where the class labels change. Eg. if I had five instances
+ * with values for the attribute of interest and labels (1.0,A),
+ * (1.4,A), (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only
+ * two cutpoints of interest: 1.85 and 5 (mid-way between the points
+ * where the classes change from A to B or vice versa).
+ */
+
+ /* initialize the hashtable */
+ // Hashtable<Object, Hashtable<Object, Integer>> facts_of_attribute =
+ // new Hashtable<Object, Hashtable<Object,
+ // Integer>>(Util.getDividingSize());
+ // for (Object attr : attributeValues) {
+ // facts_of_attribute.put(attr, new Hashtable<Object, Integer>(
+ // targetValues.size() + 1));
+ // for (Object t : targetValues) {
+ // facts_of_attribute.get(attr).put(t, 0);
+ // }
+ // facts_of_attribute.get(attr).put(attr_sum, 0);
+ // }
+ //
+ int split_index = 0;
+ Iterator<Integer> split_ite = split_facts.iterator();
+ int f1_index = split_ite.next().intValue();
+ Fact f1 = facts.get(f1_index);
+ while (split_ite.hasNext()) {
+ int f2_index = f1_index + 1;
+ Fact f2 = facts.get(f2_index);
+
+ if (f1.getFieldValue(targetAttr) == f2.getFieldValue(targetAttr)) {
+ // the cut point
+ System.out
+ .println("Bok i have splited what the fuck is happening f1:"
+ + f1 + " f2:" + f2);
+ System.exit(0);
+
+ }
+ Number cp_i = (Number) f1.getFieldValue(attributeToSplit);
+ Number cp_i_next = (Number) f2.getFieldValue(attributeToSplit);
+
+ Object cut_point = (cp_i.doubleValue() + cp_i_next.doubleValue()) / 2;
+ // calculate the gain of the cut point
+
+ /*
+ * 3. Evaluate your favourite disparity measure (info gain, gain
+ * ratio, gini coefficient, chi-squared test) on each of the
+ * cutpoints, and choose the one with the maximum value (I think
+ * Fayyad and Irani used info gain).
+ */
+ // double sum = 0.0;
+ // //for (Object attr : attributeValues) {
+ // for (int i = 1; i<2; i++) {
+ //
+ // int total_num_attr =
+ // facts_of_attribute.get(attr).get(attr_sum).intValue();
+ //
+ // double sum_attr = 0.0;
+ // if (total_num_attr > 0)
+ // for (Object t : targetValues) {
+ // int num_attr_target =
+ // facts_of_attribute.get(attr).get(t).intValue();
+ //
+ // double prob = (double) num_attr_target/ total_num_attr;
+ // // System.out.println("prob "+ prob);
+ // sum_attr += (prob == 0.0) ? 0.0 : (-1 * prob * Util.log2(prob));
+ // }
+ // sum += ((double) total_num_attr / (double) total_num_facts)*
+ // sum_attr;
+ // }
+ // getContinuousGain(facts, split_facts.subList(fromIndex,
+ // centerIndex), begin_index, middle_index,
+ // facts_in_class1, attributeToSplit);
+ //
+ // getContinuousGain(facts, split_facts.subList(centerIndex,
+ // toIndex), middle_index+1, end_index,
+ // facts_in_class2, attributeToSplit);
+ f1_index = split_ite.next().intValue();
+ f1 = facts.get(f1_index);
+ }
+
List<?> targetValues = getPossibleValues(target);
- //Hashtable<Object, Integer> facts_in_class = new Hashtable<Object, Integer>(targetValues.size());
+ // Hashtable<Object, Integer> facts_in_class = new Hashtable<Object,
+ // Integer>(targetValues.size());
return 1.0;
}
-
-
-//*OPT* public double getInformation(List<FactSet> facts) {
- Hashtable<Object, Integer> getStatistics(List<Fact> facts, String target, List<?> targetValues) {
- Hashtable<Object, Integer> facts_in_class = new Hashtable<Object, Integer>(targetValues.size());
- for (Object t: targetValues) {
+ // *OPT* public double getInformation(List<FactSet> facts) {
+ Hashtable<Object, Integer> getStatistics(List<Fact> facts, String target) {
+
+ List<?> targetValues = getPossibleValues(this.target);
+ Hashtable<Object, Integer> facts_in_class = new Hashtable<Object, Integer>(
+ targetValues.size());
+
+ for (Object t : targetValues) {
facts_in_class.put(t, 0);
}
- int total_num_facts= 0;
-//*OPT* for (FactSet fs: facts) {
-//*OPT* for (Fact f: fs.getFacts()) {
- for (Fact f: facts) {
- total_num_facts++;
- Object key = f.getFieldValue(target);
- //System.out.println("My key: "+ key.toString());
- facts_in_class.put(key, facts_in_class.get(key).intValue() + 1); // bocuk kafa :P
+ int total_num_facts = 0;
+ // *OPT* for (FactSet fs: facts) {
+ // *OPT* for (Fact f: fs.getFacts()) {
+ for (Fact f : facts) {
+ total_num_facts++;
+ Object key = f.getFieldValue(target);
+ // System.out.println("My key: "+ key.toString());
+ facts_in_class.put(key, facts_in_class.get(key).intValue() + 1); // bocuk
+ // kafa
+ // :P
}
FACTS_READ += facts.size();
-//*OPT* }
-//*OPT* }
+ // *OPT* }
+ // *OPT* }
return facts_in_class;
}
+ // *OPT* public double getInformation(List<FactSet> facts) {
+ /**
+ * it returns the information value of facts entropy that characterizes the
+ * (im)purity of an arbitrary collection of examples
+ *
+ * @param facts
+ * list of facts
+ */
+ public double getInformation_old(List<Fact> facts) {
-//*OPT* public double getInformation(List<FactSet> facts) {
- /** it returns the information value of facts
- * entropy that characterizes the (im)purity of an arbitrary collection of examples
- * @param facts list of facts
- */
- public double getInformation(List<Fact> facts) {
-
List<?> targetValues = getPossibleValues(this.target);
-
- Hashtable<Object, Integer> facts_in_class = getStatistics(facts, target, targetValues);
+ Hashtable<Object, Integer> facts_in_class = getStatistics(facts,
+ getTarget()); // , targetValues)
+ // Hashtable<Object, Integer> facts_in_class = getStatistics(facts,
+ // getTarget(), targetValues);
int total_num_facts = facts.size();
double sum = 0;
- for (Object key: targetValues) {
+ for (Object key : targetValues) {
int num_in_class = facts_in_class.get(key).intValue();
- //System.out.println("num_in_class : "+ num_in_class + " key "+ key + " and the total num "+ total_num_facts);
+ // System.out.println("num_in_class : "+ num_in_class + " key "+ key
+ // + " and the total num "+ total_num_facts);
double prob = (double) num_in_class / (double) total_num_facts;
-
- //double log2= Util.log2(prob);
- //double plog2p= prob*log2;
- sum += (prob == 0.0) ? 0.0 :-1* prob * Util.log2(prob);
- //System.out.println("prob "+ prob +" and the plog(p)"+plog2p+" where the sum: "+sum);
+
+ // double log2= Util.log2(prob);
+ // double plog2p= prob*log2;
+ sum += (prob == 0.0) ? 0.0 : -1 * prob * Util.log2(prob);
+ // System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"
+ // where the sum: "+sum);
}
return sum;
}
+ public double getAttrInformation( Hashtable<Object, Hashtable<Object, Integer>> facts_of_attribute, int fact_size) {
+ // Returns the weighted sum, over the buckets (keys) of facts_of_attribute, of getInformation(bucket, bucket total), each weighted by (bucket total / fact_size).
+ Collection<Object> attributeValues = facts_of_attribute.keySet();
+ String attr_sum = Util.getSum(); // key under which each bucket table stores its total count
+ double sum = 0.0;
+ for (Object attr : attributeValues) {
+ int total_num_attr = facts_of_attribute.get(attr).get(attr_sum).intValue();
+ // buckets whose total is 0 contribute nothing and are skipped
+ if (total_num_attr > 0) {
+ sum += ((double) total_num_attr / (double) fact_size)*
+ getInformation(facts_of_attribute.get(attr), total_num_attr); // NOTE(review): the bucket table passed here also holds the attr_sum entry, which getInformation iterates over like a class count -- confirm this is intended
+ }
+ }
+ return sum;
+ }
+
+ public double getInformation(Hashtable<Object, Integer> facts_in_class, int total_num_facts) {
+ // Returns the sum of -p * log2(p) over every key of facts_in_class, where p = (count stored for the key) / total_num_facts.
+ // Unlike getInformation_old, the class counts are supplied by the caller
+ // instead of being recomputed with getStatistics, and the keys iterated
+ // are those of facts_in_class rather than getPossibleValues(target).
+ Collection<Object> targetValues = facts_in_class.keySet();
+ double sum = 0;
+ for (Object key : targetValues) {
+ int num_in_class = facts_in_class.get(key).intValue();
+ // System.out.println("num_in_class : "+ num_in_class + " key "+ key
+ // + " and the total num "+ total_num_facts);
+ double prob = (double) num_in_class / (double) total_num_facts;
+
+ // a zero probability contributes 0, so log2 is never evaluated at 0
+ // NOTE(review): total_num_facts == 0 would make prob NaN or Infinity -- callers presumably guarantee a positive total
+ sum += (prob == 0.0) ? 0.0 : -1 * prob * Util.log2(prob);
+ // System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"
+ // where the sum: "+sum);
+ }
+ return sum;
+ }
+
public void setTarget(String targetField) {
target = targetField;
attrsToClassify.remove(target);
@@ -236,13 +549,13 @@
domainSet.put(domain.getName(), domain);
if (!domain.getName().equals(this.target))
attrsToClassify.add(domain.getName());
-
+
}
public List<?> getPossibleValues(String fieldName) {
return domainSet.get(fieldName).getValues();
}
-
+
public List<String> getAttributes() {
return attrsToClassify;
}
@@ -250,46 +563,40 @@
public String getTarget() {
return target;
}
-
+
public String getName() {
return className;
}
-
public Domain<?> getDomain(String key) {
return domainSet.get(key);
}
-
public TreeNode getRoot() {
- return(root);
-
+ return (root);
+
}
-
+
public void setRoot(TreeNode root) {
this.root = root;
-
+
}
-
+
public long getNumRead() {
return FACTS_READ;
}
+
@Override
public String toString() {
return "Facts scanned " + FACTS_READ + "\n" + root.toString();
}
-
-
- /* **OPT
- int getTotalSize(List<FactSet> facts) {
- int num = 0;
- for(FactSet fs : facts) {
- num += fs.getSize();
- }
+ /*
+ * **OPT int getTotalSize(List<FactSet> facts) {
+ *
+ * int num = 0; for(FactSet fs : facts) { num += fs.getSize(); }
+ *
+ * return num; }
+ */
- return num;
- }
- */
-
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTreeBuilder.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTreeBuilder.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -4,7 +4,9 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.Comparator;
import java.util.Hashtable;
+import java.util.Iterator;
import java.util.List;
public class DecisionTreeBuilder {
@@ -44,7 +46,9 @@
// **OPT List<FactSet> facts = new ArrayList<FactSet>();
ArrayList<Fact> facts = new ArrayList<Fact>();
FactSet klass_fs = null;
- for (FactSet fs: wm.getFactsets()) {
+ Iterator<FactSet> it_fs= wm.getFactsets();
+ while (it_fs.hasNext()) {
+ FactSet fs = it_fs.next();
if (fs instanceof OOFactSet) {
if (klass.isAssignableFrom(((OOFactSet) fs).getFactClass())) {
// **OPT facts.add(fs);
@@ -85,7 +89,9 @@
// **OPT List<FactSet> facts = new ArrayList<FactSet>();
ArrayList<Fact> facts = new ArrayList<Fact>();
FactSet klass_fs = null;
- for (FactSet fs: wm.getFactsets()) {
+ Iterator<FactSet> it_fs= wm.getFactsets();
+ while (it_fs.hasNext()) {
+ FactSet fs = it_fs.next();
if (klass == fs.getClassName()) {
// **OPT facts.add(fs);
fs.assignTo(facts); // adding all facts of fs to "facts"
@@ -129,9 +135,88 @@
throw new RuntimeException("Nothing to classify, factlist is empty");
}
/* let's get the statistics of the results */
- List<?> targetValues = dt.getPossibleValues(dt.getTarget());
- Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget(), targetValues);
+ //List<?> targetValues = dt.getPossibleValues(dt.getTarget());
+ Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget());//targetValues
+ Collection<Object> targetValues = stats.keySet();
+ int winner_vote = 0;
+ int num_supporters = 0;
+ Object winner = null;
+ for (Object key: targetValues) {
+ int num_in_class = stats.get(key).intValue();
+ if (num_in_class>0)
+ num_supporters ++;
+ if (num_in_class > winner_vote) {
+ winner_vote = num_in_class;
+ winner = key;
+ }
+ }
+
+ /* if all elements are classified to the same value */
+ if (num_supporters == 1) {
+ //*OPT* return new LeafNode(facts.get(0).getFact(0).getFieldValue(target));
+ LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+ classifiedNode.setRank((double)facts.size()/(double)num_fact_processed);
+ return classifiedNode;
+ }
+
+ /* if there is no attribute left in order to continue */
+ if (attributeNames.size() == 0) {
+ /* an heuristic of the leaf classification*/
+ LeafNode noAttributeLeftNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+ noAttributeLeftNode.setRank((double)winner_vote/(double)num_fact_processed);
+ return noAttributeLeftNode;
+ }
+
+ /* id3 starts */
+ String chosenAttribute = attributeWithGreatestGain_discrete(dt, facts, stats, attributeNames);
+
+ System.out.println(Util.ntimes("*", 20)+" 1st best attr: "+ chosenAttribute);
+
+ TreeNode currentNode = new TreeNode(dt.getDomain(chosenAttribute));
+ //ConstantDecisionTree m = majorityValue(ds);
+ /* the majority */
+
+ List<?> attributeValues = dt.getPossibleValues(chosenAttribute);
+ Hashtable<Object, List<Fact> > filtered_facts = splitFacts(facts, chosenAttribute, attributeValues);
+ dt.FACTS_READ += facts.size();
+
+
+// if (FUNC_CALL ==5) {
+// System.out.println("FUNC_CALL:" +FUNC_CALL);
+// System.exit(0);
+// }
+ for (int i = 0; i < attributeValues.size(); i++) {
+ /* split the last two class at the same time */
+ Object value = attributeValues.get(i);
+
+ ArrayList<String> attributeNames_copy = new ArrayList<String>(attributeNames);
+ attributeNames_copy.remove(chosenAttribute);
+
+ if (filtered_facts.get(value).isEmpty()) {
+ /* majority !!!! */
+ LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+ majorityNode.setRank(0.0);
+ currentNode.addNode(value, majorityNode);
+ } else {
+ TreeNode newNode = id3(dt, filtered_facts.get(value), attributeNames_copy);
+ currentNode.addNode(value, newNode);
+ }
+ }
+
+ return currentNode;
+ }
+
+private TreeNode c4_5(DecisionTree dt, List<Fact> facts, List<String> attributeNames) {
+
+ FUNC_CALL ++;
+ if (facts.size() == 0) {
+ throw new RuntimeException("Nothing to classify, factlist is empty");
+ }
+ /* let's get the statistics of the results */
+ //List<?> targetValues = dt.getPossibleValues(dt.getTarget());
+ Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget());//targetValues
+ Collection<Object> targetValues = stats.keySet();
int winner_vote = 0;
int num_supporters = 0;
Object winner = null;
@@ -163,7 +248,7 @@
}
/* id3 starts */
- String chosenAttribute = attributeWithGreatestGain(dt, facts, attributeNames);
+ String chosenAttribute = attributeWithGreatestGain(dt, facts, stats, attributeNames);
System.out.println(Util.ntimes("*", 20)+" 1st best attr: "+ chosenAttribute);
@@ -202,9 +287,10 @@
}
//String chooseAttribute(List<FactSet> facts, List<String> attrs) {
- public String attributeWithGreatestGain(DecisionTree dt, List<Fact> facts, List<String> attrs) {
+ public String attributeWithGreatestGain(DecisionTree dt, List<Fact> facts,
+ Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
- double dt_info = dt.getInformation(facts);
+ double dt_info = dt.getInformation(facts_in_class, facts.size());
double greatestGain = 0.0;
String attributeWithGreatestGain = attrs.get(0);
for (String attr : attrs) {
@@ -212,7 +298,15 @@
if (dt.getDomain(attr).isDiscrete()) {
gain = dt_info - dt.getGain(facts, attr);
} else {
- gain = dt_info - dt.getContinuousGain(facts, attr);
+ /* 1. sort the values */
+ int begin_index = 0;
+ int end_index = facts.size();
+ Collections.sort(facts, new FactNumericAttributeComparator(attr));
+ List<Integer> splits = getSplitPoints(facts, dt.getTarget());
+ gain = dt_info - dt.getContinuousGain(facts, splits,
+ begin_index, end_index,
+ facts_in_class, attr);
+ //gain = dt_info - dt.getContinuousGain(facts, facts_in_class, attr);
}
System.out.println("Attribute: "+attr +" the gain: "+gain);
@@ -224,7 +318,53 @@
return attributeWithGreatestGain;
}
+    /**
+     * Selects the discrete attribute with the greatest information gain.
+     * ID3 uses this variant because it cannot classify continuous attributes;
+     * continuous attributes are reported on stderr and skipped.
+     *
+     * @param dt             decision tree supplying domains, entropy and gain
+     * @param facts          facts at the current node
+     * @param facts_in_class number of facts per target value, as produced by
+     *                       dt.getStatistics for the same facts
+     * @param attrs          candidate attribute names
+     * @return the discrete attribute with the largest gain; when no discrete
+     *         attribute has a gain above zero, the first discrete attribute;
+     *         when attrs holds no discrete attribute at all, attrs.get(0)
+     * @throws IndexOutOfBoundsException if attrs is empty
+     */
+    public String attributeWithGreatestGain_discrete(DecisionTree dt, List<Fact> facts,
+            Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
+
+        double dt_info = dt.getInformation(facts_in_class, facts.size());
+        double greatestGain = 0.0;
+        // Starts as null so that a skipped continuous attribute can never be the default.
+        String attributeWithGreatestGain = null;
+        for (String attr : attrs) {
+            if (!dt.getDomain(attr).isDiscrete()) {
+                System.err.println("Ignoring the attribute:" +attr+ " the id3 can not classify continuous attributes");
+                continue;
+            }
+            double gain = dt_info - dt.getGain(facts, attr);
+            System.out.println("Attribute: " + attr + " the gain: " + gain);
+            if (attributeWithGreatestGain == null) {
+                // First usable attribute is the default when every gain is zero.
+                attributeWithGreatestGain = attr;
+            }
+            if (gain > greatestGain) {
+                greatestGain = gain;
+                attributeWithGreatestGain = attr;
+            }
+        }
+
+        if (attributeWithGreatestGain == null) {
+            // No discrete attribute at all: keep the historical fallback so callers
+            // still get a name (and the same exception for an empty list).
+            attributeWithGreatestGain = attrs.get(0);
+        }
+        return attributeWithGreatestGain;
+    }
+    /**
+     * Collects the candidate split positions of a fact list: every index i at
+     * which the target value of facts.get(i) differs from the target value of
+     * facts.get(i + 1). The list is expected to be sorted by the attribute being
+     * discretized before this is called.
+     *
+     * @param facts  facts to scan, in their current order
+     * @param target name of the target (class) field
+     * @return indices of the last fact before each change of target value;
+     *         empty when facts has fewer than two elements
+     */
+    private List<Integer> getSplitPoints(List<Fact> facts, String target) {
+        List<Integer> splits = new ArrayList<Integer>();
+        Iterator<Fact> it_f = facts.iterator();
+        if (!it_f.hasNext()) {
+            // Nothing to compare; avoids NoSuchElementException on an empty list.
+            return splits;
+        }
+        Object previous = it_f.next().getFieldValue(target);
+        int index = 0;
+        while (it_f.hasNext()) {
+            Object current = it_f.next().getFieldValue(target);
+            // Compare by value: equal boxed numbers or strings are usually distinct
+            // objects, so a reference comparison would report spurious class changes.
+            boolean sameClass = (previous == null) ? (current == null) : previous.equals(current);
+            if (!sameClass) {
+                splits.add(Integer.valueOf(index));
+            }
+            previous = current;
+            index++;
+        }
+        return splits;
+    }
+
+
public Hashtable<Object, List<Fact> > splitFacts(List<Fact> facts, String attributeName,
List<?> attributeValues) {
Hashtable<Object, List<Fact> > factLists = new Hashtable<Object, List<Fact> >(attributeValues.size());
@@ -238,11 +378,12 @@
}
public void testEntropy(DecisionTree dt, List<Fact> facts) {
- double initial_info = dt.getInformation(facts); //entropy value
+ Hashtable<Object, Integer> facts_in_class = dt.getStatistics(facts, dt.getTarget());//, targetValues
+ double initial_info = dt.getInformation(facts_in_class, facts.size()); //entropy value
System.out.println("initial_information: "+ initial_info);
- String first_attr = attributeWithGreatestGain(dt, facts, dt.getAttributes());
+ String first_attr = attributeWithGreatestGain(dt, facts, facts_in_class, dt.getAttributes());
System.out.println("best attr: "+ first_attr);
}
@@ -250,5 +391,24 @@
public int getNumCall() {
return FUNC_CALL;
}
+
+    /**
+     * Orders facts by the numeric value of one attribute, ascending.
+     * The attribute value of every compared fact must be a non-null Number.
+     */
+    private class FactNumericAttributeComparator implements Comparator<Fact> {
+        private final String attr_name;
+
+        /**
+         * @param _attr_name name of the numeric attribute to sort by
+         */
+        public FactNumericAttributeComparator(String _attr_name) {
+            attr_name = _attr_name;
+        }
+
+        /**
+         * Compares the two facts by the double value of the attribute.
+         * Double.compare gives a total order (NaN included), which the
+         * Comparator contract requires for Collections.sort.
+         */
+        public int compare(Fact f0, Fact f1) {
+            Number n0 = (Number) f0.getFieldValue(attr_name);
+            Number n1 = (Number) f1.getFieldValue(attr_name);
+            return Double.compare(n0.doubleValue(), n1.doubleValue());
+        }
+    }
+
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTreeBuilderMT.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTreeBuilderMT.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DecisionTreeBuilderMT.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -5,6 +5,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.Hashtable;
+import java.util.Iterator;
import java.util.List;
public class DecisionTreeBuilderMT {
@@ -44,7 +45,9 @@
// **OPT List<FactSet> facts = new ArrayList<FactSet>();
ArrayList<Fact> facts = new ArrayList<Fact>();
FactSet klass_fs = null;
- for (FactSet fs: wm.getFactsets()) {
+ Iterator<FactSet> it_fs= wm.getFactsets();
+ while (it_fs.hasNext()) {
+ FactSet fs = it_fs.next();
if (fs instanceof OOFactSet) {
if (klass.isAssignableFrom(((OOFactSet)fs).getFactClass())) {
// **OPT facts.add(fs);
@@ -96,7 +99,9 @@
// **OPT List<FactSet> facts = new ArrayList<FactSet>();
ArrayList<Fact> facts = new ArrayList<Fact>();
FactSet klass_fs = null;
- for (FactSet fs: wm.getFactsets()) {
+ Iterator<FactSet> it_fs= wm.getFactsets();
+ while (it_fs.hasNext()) {
+ FactSet fs = it_fs.next();
if (klass == fs.getClassName()) {
// **OPT facts.add(fs);
fs.assignTo(facts); // adding all facts of fs to "facts"
@@ -178,7 +183,7 @@
}
/* let's get the statistics of the results */
List<?> targetValues = dt.getPossibleValues(dt.getTarget());
- Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget(), targetValues);
+ Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget());//,targetValues
int winner_vote = 0;
int num_supporters = 0;
@@ -211,7 +216,7 @@
}
/* id3 starts */
- String chosenAttribute = attributeWithGreatestGain(dt, facts, attributeNames);
+ String chosenAttribute = attributeWithGreatestGain(dt, facts, stats, attributeNames);
System.out.println(Util.ntimes("*", 20)+" 1st best attr: "+ chosenAttribute);
@@ -264,9 +269,9 @@
}
//String chooseAttribute(List<FactSet> facts, List<String> attrs) {
- public String attributeWithGreatestGain(DecisionTree dt, List<Fact> facts, List<String> attrs) {
+ public String attributeWithGreatestGain(DecisionTree dt, List<Fact> facts, Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
- double dt_info = dt.getInformation(facts);
+ double dt_info = dt.getInformation(facts_in_class, facts.size());
double greatestGain = 0.0;
String attributeWithGreatestGain = attrs.get(0);
for (String attr : attrs) {
@@ -294,11 +299,13 @@
}
public void testEntropy(DecisionTree dt, List<Fact> facts) {
- double initial_info = dt.getInformation(facts); //entropy value
+ Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget());
+
+ double initial_info = dt.getInformation(stats, facts.size()); //entropy value
System.out.println("initial_information: "+ initial_info);
- String first_attr = attributeWithGreatestGain(dt, facts, dt.getAttributes());
+ String first_attr = attributeWithGreatestGain(dt, facts, stats, dt.getAttributes());
System.out.println("best attr: "+ first_attr);
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Domain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Domain.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Domain.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -8,7 +8,7 @@
void setConstant();
boolean isDiscrete();
- //void setConstant();
+ void setDiscrete(boolean disc);
boolean contains(T value);
@@ -22,6 +22,9 @@
String toString();
boolean isPossible(Object value) throws Exception;
+
+ void setReadingSeq(int readingSeq);
+ int getReadingSeq();
}
Copied: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DomainSpec.java (from rev 19268, labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/ReadingSeq.java)
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DomainSpec.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/DomainSpec.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -0,0 +1,12 @@
+package id3;
+
+import java.lang.annotation.*;
+
+/**
+ * Describes how a field or method of a fact class maps onto a domain.
+ * Retained at runtime so it can be read through reflection.
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target({ElementType.METHOD, ElementType.FIELD})
+public @interface DomainSpec {
+    /** Index of this attribute's value within a separator-split input line. */
+    int readingSeq();
+    /** Whether this attribute is the target (class) attribute; false by default. */
+    boolean target() default false;
+    /** Whether the attribute's domain is discrete; true by default. */
+    boolean discrete() default true;
+    /** NOTE(review): presumably the allowed values of the attribute; "bok" looks like a placeholder default - confirm. */
+    String[] values() default {"bok"};
+}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Fact.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Fact.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Fact.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -1,4 +1,5 @@
package id3;
+import java.util.Comparator;
import java.util.Hashtable;
import java.util.Set;
@@ -72,4 +73,5 @@
}
return out;
}
+
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/FactSetFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/FactSetFactory.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/FactSetFactory.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -142,13 +142,13 @@
ArrayList<Domain<?>> domains = new ArrayList<Domain<?>>();
NumericDomain height = new NumericDomain("height");
- height.setContinuous();
+ height.setDiscrete(false);
NumericDomain width = new NumericDomain("width");
- height.setContinuous();
+ height.setDiscrete(false);
NumericDomain aratio = new NumericDomain("aratio");
- height.setContinuous();
+ height.setDiscrete(false);
domains.add(height);
domains.add(width);
domains.add(aratio);
@@ -236,13 +236,14 @@
return false;
}
- public static void fromFileAsObject(WorkingMemory wm, Class<?> klass, String filename, String separator)
+ public static List<Object> fromFileAsObject(WorkingMemory wm, Class<?> klass, String filename, String separator)
throws IOException {
+ List<Object> obj_read = new ArrayList<Object>();
OOFactSet fs = wm.getFactSet(klass);
Collection<Domain<?>> domains = fs.getDomains();
BufferedReader reader = new BufferedReader(new InputStreamReader(
- FactSetFactory.class.getResourceAsStream(filename)));// "../data/"
+ klass.getResourceAsStream(filename)));// "../data/"
// +
String line;
while ((line = reader.readLine()) != null) {
@@ -255,9 +256,11 @@
break;
Object element = ObjectReader.read(klass, domains, line, separator);
//System.out.println("New object "+ element);
+ obj_read.add(element);
fs.insert(element);
+
}
- return;
+ return obj_read;
}
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/LiteralDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/LiteralDomain.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/LiteralDomain.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -9,28 +9,29 @@
private String fName;
private List<String> fValues;
private boolean constant;
- //private boolean discrete;
+ private boolean discrete;
+ private int readingSeq;
public LiteralDomain(String _name) {
fName = _name.trim();
fValues = new ArrayList<String>();
- //discrete = true;
+ discrete = true;
}
public LiteralDomain(String _name, String[] possibleValues) {
fName = _name;
fValues = Arrays.asList(possibleValues);
- //discrete = true;
+ discrete = true;
}
-// public void setContinuous() {
-// discrete = false;
-// }
+ // Sets the flag that isDiscrete() reports for this domain.
+ public void setDiscrete(boolean d) {
+ this.discrete = d;
+ }
public boolean isDiscrete() {
- return true;
+ return this.discrete;
}
public String getName() {
@@ -86,6 +87,16 @@
return true;
}
+ // Stores the reading sequence of this domain (the value returned by getReadingSeq()).
+ public void setReadingSeq(int readingSeq) {
+ this.readingSeq = readingSeq;
+
+ }
+
+ // Returns the reading sequence last stored with setReadingSeq (0 if never set).
+ public int getReadingSeq() {
+ return this.readingSeq;
+
+ }
+
public String toString() {
String out = fName;
return out;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/NumericDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/NumericDomain.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/NumericDomain.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -9,6 +9,7 @@
private ArrayList<Number> fValues;
private boolean constant;
private boolean discrete;
+ private int readingSeq;
public NumericDomain(String _name) {
@@ -16,8 +17,8 @@
fValues = new ArrayList<Number>();
discrete = true;
}
- public void setContinuous() {
- discrete = false;
+ public void setDiscrete(boolean d) {
+ this.discrete = d;
}
public boolean isDiscrete() {
@@ -143,6 +144,16 @@
return false;
}
+ // Stores the reading sequence of this domain (the value returned by getReadingSeq()).
+ public void setReadingSeq(int readingSeq) {
+ this.readingSeq = readingSeq;
+
+ }
+
+ // Returns the reading sequence last stored with setReadingSeq (0 if never set).
+ public int getReadingSeq() {
+ return this.readingSeq;
+
+ }
+
public String toString() {
String out = fName;
return out;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/ObjectReader.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/ObjectReader.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/ObjectReader.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -17,8 +17,148 @@
private static final boolean DEBUG = false;
+ /**
+ * Builds one instance of element_class from a single line of separated text.
+ *
+ * A trailing "." is removed from the line, which is then split on separator
+ * (used as a regular expression by String.split). For every declared setter
+ * of element_class that takes a simple type, the domain whose name matches
+ * the setter's attribute name (ignoring case) is looked up; the token at
+ * that domain's reading sequence is parsed with the domain and passed to
+ * the setter.
+ *
+ * NOTE: a mismatch between the domain kind and the setter's parameter type
+ * prints a message and terminates the JVM through System.exit(0).
+ * Reflection failures are only printed; the affected field is left unset.
+ *
+ * @param element_class class to instantiate; needs an accessible no-arg constructor
+ * @param domains domains describing the attributes of element_class
+ * @param data one input line
+ * @param separator regular expression separating the values in data
+ * @return the populated object, or null when element_class could not be instantiated
+ */
+ public static Object read(Class<?> element_class,
+ Collection<Domain<?>> domains, String data, String separator) {
+
+ // System.out.println("BOK BOK domains: "+ domains.size());
+ Object element = null;
+ try {
+ // element = Class.forName(element_class.getName());
+
+ element = element_class.newInstance();
+
+ Method[] element_methods = element_class.getDeclaredMethods();
+
+ // drop a trailing '.' before tokenizing the line
+ if (data.endsWith("."))
+ data = data.substring(0, data.length() - 1);
+ List<String> attributeValues = Arrays.asList(data.split(separator));
+
+ for (Method m : element_methods) {
+ String m_name = m.getName();
+ Class<?>[] param_type_name = m.getParameterTypes();
+ // only setters taking a simple type are populated
+ if (Util.isSetter(m_name) & Util.isSimpleType(param_type_name)) {
+ // if (!Util.isSimpleType(return_type_name))
+ // continue; // in the future we should support classes
+ /*
+ * Annotation[] annotations = m.getAnnotations();
+ * // iterate over the annotations to locate the MaxLength
+ * constraint if it exists DomainSpec spec = null; for
+ * (Annotation a : annotations) { if (a instanceof
+ * DomainSpec) { spec = (DomainSpec)a; // here it is !!!
+ * break; } } if (DEBUG) System.out.println("What annotation
+ * i found: "+ spec + " for method "+ m); String fieldString =
+ * attributeValues.get(spec.readingSeq());
+ *
+ */
+
+ String field = Util.getAttributeName(m_name);
+
+ // find the domain whose name matches this setter's attribute
+ Iterator<Domain<?>> domain_it = domains.iterator();
+ // Iterator<String> value_it = attributeValues.iterator();
+ while (domain_it.hasNext()) {
+ Domain<?> attr_domain = domain_it.next();
+ // String name = attr_domain.getName();
+ if (field.equalsIgnoreCase(attr_domain.getName())) {
+
+ // the domain's reading sequence is the index of the token to parse
+ String fieldString = attributeValues
+ .get(attr_domain.getReadingSeq());
+ Object fieldValue = attr_domain
+ .readString(fieldString);
+
+ // numeric values are narrowed to the setter's parameter type;
+ // "double" needs no conversion, anything else terminates the JVM
+ if (attr_domain instanceof NumericDomain) {
+ if (param_type_name[0].getName()
+ .equalsIgnoreCase("int")) {
+ fieldValue = ((Number) fieldValue)
+ .intValue();
+
+ } else if (param_type_name[0].getName()
+ .equalsIgnoreCase("float")) {
+ fieldValue = ((Number) fieldValue)
+ .floatValue();
+
+ } else if (!param_type_name[0].getName()
+ .equalsIgnoreCase("double")) {
+ System.out
+ .println("What the hack, which type of number is this??");
+ fieldValue = ((Number) fieldValue)
+ .doubleValue();
+ System.exit(0);
+ }
+ // literal domains require a String setter parameter
+ } else if (attr_domain instanceof LiteralDomain) {
+ if (param_type_name[0].getName()
+ .equalsIgnoreCase("java.lang.String")) {
+ } else {
+ System.out
+ .println("What the hack, which type of string is this?? "
+ + fieldValue);
+ System.exit(0);
+ }
+ // boolean domains require a boolean setter parameter
+ } else if (attr_domain instanceof BooleanDomain) {
+ if (param_type_name[0].getName()
+ .equalsIgnoreCase("boolean")) {
+ } else {
+ System.out
+ .println("What the hack, which type of boolean is this?? "
+ + fieldValue);
+ System.exit(0);
+ }
+ } else {
+ // unknown domain implementation
+ System.out
+ .println("What the hack, which type of object is this?? "
+ + fieldValue);
+ System.exit(0);
+ }
+
+ // String fieldValue = fieldString;
+
+ try {
+
+ if (DEBUG)
+ System.out.println("ObjectReader.read obj "
+ + element.getClass()
+ + " fielddomain name "
+ + attr_domain.getName()
+ + " value: " + fieldValue);
+ if (DEBUG)
+ System.out
+ .println("ObjectReader.read method "
+ + m
+ + " the parameter type:"
+ + fieldValue.getClass());
+ // call the setter with the parsed value
+ m.invoke(element, fieldValue);
+
+ } catch (IllegalArgumentException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } catch (IllegalAccessException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } catch (InvocationTargetException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ // at most one domain is applied per setter
+ break;
+ }
+
+ }
+ }
+ }
+
+ } catch (InstantiationException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } catch (IllegalAccessException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ return element;
+
+ }
//read(Class<?> element_class, Collection<Domain<?>> collection, String data, String separator)
- public static Object read(Class<?> element_class, Collection<Domain<?>> domains, String data, String separator) {
+ public static Object read_(Class<?> element_class, Collection<Domain<?>> domains, String data, String separator) {
Object element= null;
try {
@@ -48,15 +188,15 @@
Annotation[] annotations = m.getAnnotations();
// iterate over the annotations to locate the MaxLength constraint if it exists
- ReadingSeq sequence = null;
+ DomainSpec spec = null;
for (Annotation a : annotations) {
- if (a instanceof ReadingSeq) {
- sequence = (ReadingSeq)a; // here it is !!!
+ if (a instanceof DomainSpec) {
+ spec = (DomainSpec)a; // here it is !!!
break;
}
}
- if (DEBUG) System.out.println("What annotation i found: "+ sequence + " for method "+ m);
- String fieldString = attributeValues.get(sequence.value());
+ if (DEBUG) System.out.println("What annotation i found: "+ spec + " for method "+ m);
+ String fieldString = attributeValues.get(spec.readingSeq());
String field = Util.getAttributeName(m_name);
Iterator<Domain<?>> domain_it = domains.iterator();
@@ -65,9 +205,41 @@
Domain<?> attr_domain = domain_it.next();
//String name = attr_domain.getName();
if (field.equalsIgnoreCase(attr_domain.getName())) {
- //String fieldValue = attr_domain.readString(fieldString);
- String fieldValue = fieldString;
+ Object fieldValue = attr_domain.readString(fieldString);
+
+ if (attr_domain instanceof NumericDomain) {
+ if (param_type_name[0].getName().equalsIgnoreCase("int")) {
+ fieldValue = ((Number)fieldValue).intValue();
+
+ } else if (param_type_name[0].getName().equalsIgnoreCase("float")) {
+ fieldValue = ((Number)fieldValue).floatValue();
+
+ } else if (!param_type_name[0].getName().equalsIgnoreCase("double")) {
+ System.out.println("What the hack, which type of number is this??");
+ fieldValue = ((Number)fieldValue).doubleValue();
+ System.exit(0);
+ }
+ } else if (attr_domain instanceof LiteralDomain) {
+ if (param_type_name[0].getName().equalsIgnoreCase("java.lang.String")) {
+ } else {
+ System.out.println("What the hack, which type of string is this?? " + fieldValue);
+ System.exit(0);
+ }
+ } else if (attr_domain instanceof BooleanDomain) {
+ if (param_type_name[0].getName().equalsIgnoreCase("boolean")) {
+ } else {
+ System.out.println("What the hack, which type of boolean is this?? " + fieldValue);
+ System.exit(0);
+ }
+ } else {
+ System.out.println("What the hack, which type of object is this?? " + fieldValue);
+ System.exit(0);
+ }
+
+
+ // String fieldValue = fieldString;
+
try {
if (DEBUG) System.out.println("ObjectReader.read obj "+ element.getClass() + " fielddomain name "+attr_domain.getName()+" value: "+fieldValue);
Deleted: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/ReadingSeq.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/ReadingSeq.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/ReadingSeq.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -1,9 +0,0 @@
-package id3;
-
-import java.lang.annotation.*;
-
- at Retention(RetentionPolicy.RUNTIME)
- at Target(ElementType.METHOD)
-public @interface ReadingSeq {
- int value();
-}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Util.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/Util.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -1,5 +1,10 @@
package id3;
+import java.lang.annotation.Annotation;
+import java.lang.reflect.Field;
+import java.util.Hashtable;
+import java.util.List;
+
public class Util {
public static String ntimes(String s,int n){
@@ -22,6 +27,24 @@
else
return false;
}
+
+// public static boolean castType(Object value, Class<?>[] type_name) {
+//// simpletype.contains(type_name)
+// if (type_name.length!=1)
+// return false;
+//
+// if (type_name[0].getName().equalsIgnoreCase("boolean")) {
+//
+// } else if (type_name[0].getName().equalsIgnoreCase("int")) {
+//
+// } else if (type_name[0].getName().equalsIgnoreCase("double")) {
+// } else if (type_name[0].getName().equalsIgnoreCase("float")){
+//
+// } else if (type_name[0].getName().equalsIgnoreCase("java.lang.String")){
+// return true;
+// }else
+// return false;
+// }
public static boolean isGetter(String method_name) {
if (method_name.startsWith("get") || method_name.startsWith("is") )
@@ -47,6 +70,38 @@
return Math.log(prob) / Math.log(2);
}
+ // Returns the constant 2.
+ // NOTE(review): presumably the number of parts a continuous attribute is divided into - confirm against callers.
+ public static int getDividingSize() {
+ return 2;
+ }
+    /**
+     * Finds the field of a class that is marked as the target attribute.
+     * Only declared fields of a simple type are considered; the first one
+     * carrying a DomainSpec annotation with target() set wins.
+     *
+     * @param classObj class whose declared fields are inspected
+     * @return name of the target field, or null when no field is marked
+     */
+    public static String getTargetAnnotation(Class<? extends Object> classObj) {
+        for (Field candidate : classObj.getDeclaredFields()) {
+            Class<?>[] candidateType = {candidate.getType()};
+            if (!Util.isSimpleType(candidateType)) {
+                continue;
+            }
+            // A field carries at most one DomainSpec, so a direct lookup is
+            // equivalent to scanning all of its annotations.
+            DomainSpec spec = candidate.getAnnotation(DomainSpec.class);
+            if (spec != null && spec.target()) {
+                return candidate.getName();
+            }
+        }
+        return null;
+    }
+
+ // Returns the literal string "sum".
+ // NOTE(review): presumably a reserved key for a total count in a statistics table - confirm against callers.
+ public static String getSum() {
+ return "sum";
+ }
+
+
+
}
\ No newline at end of file
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/WorkingMemory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/WorkingMemory.java 2008-03-29 20:03:39 UTC (rev 19314)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/id3/WorkingMemory.java 2008-03-30 01:10:03 UTC (rev 19315)
@@ -1,8 +1,10 @@
package id3;
+import java.lang.annotation.Annotation;
+import java.lang.reflect.Field;
import java.lang.reflect.Method;
-import java.util.Collection;
import java.util.Hashtable;
+import java.util.Iterator;
import java.util.List;
public class WorkingMemory {
@@ -79,15 +81,58 @@
* newfs.adddomain(d)=> why do you add this the factset?
* we said that the domains should be independent from the factset
*/
+
+ /*
+ * Creates the fact set for classObj from its declared fields.
+ * Every field of a simple type gets a domain: the one already registered
+ * in domainset under the field name, or a new one created from the field
+ * type. When the field carries a DomainSpec annotation, its reading
+ * sequence and discrete flag are copied onto the domain. The domain is
+ * (re)registered in domainset and added to the new fact set, which is
+ * stored in factsets under the class name and returned.
+ */
private OOFactSet create_factset(Class<?> classObj) {
//System.out.println("WorkingMemory.create_factset element "+ element );
OOFactSet newfs = new OOFactSet(classObj);
+
+ Field [] element_fields = classObj.getDeclaredFields();
+ for( Field f: element_fields) {
+ String f_name = f.getName();
+ Class<?>[] f_class = {f.getType()};
+ // debug trace, printed for every declared field
+ System.out.println("WHat is this f: " +f.getType()+" the name "+f_name+" class "+ f.getClass() + " and the name"+ f.getClass().getName());
+ if (Util.isSimpleType(f_class)) {
+
+ Annotation[] annotations = f.getAnnotations();
+
+ // iterate over the annotations to locate the DomainSpec annotation if it exists
+ DomainSpec spec = null;
+ for (Annotation a : annotations) {
+ if (a instanceof DomainSpec) {
+ spec = (DomainSpec)a; // here it is !!!
+ break;
+ }
+ }
+
+ // reuse the domain already known under this field name, otherwise create one
+ Domain<?> fieldDomain;
+ if (!domainset.containsKey(f_name))
+ fieldDomain = DomainFactory.createDomainFromClass(f.getType(), f_name);
+ else
+ fieldDomain = domainset.get(f_name);
+
+ //System.out.println("WorkingMemory.create_factset field "+ field + " fielddomain name "+fieldDomain.getName()+" return_type_name: "+return_type_name+".");
+ // fields without a DomainSpec keep the domain's existing settings
+ if (spec != null) {
+ fieldDomain.setReadingSeq(spec.readingSeq());
+ fieldDomain.setDiscrete(spec.discrete());
+ }
+ domainset.put(f_name, fieldDomain);
+ newfs.addDomain(f_name, fieldDomain);
+
+ }
+ }
+ factsets.put(classObj.getName(), newfs);
+ return newfs;
+ }
+
+ private OOFactSet create_factset_(Class<?> classObj) {
+ //System.out.println("WorkingMemory.create_factset element "+ element );
+
+ OOFactSet newfs = new OOFactSet(classObj);
Method [] element_methods = classObj.getDeclaredMethods();
for( Method m: element_methods) {
-
-
+
String m_name = m.getName();
Class<?>[] returns = {m.getReturnType()};
//System.out.println("WorkingMemory.create_factset m "+ m + " method name "+m_name+" return_type_name: "+return_type_name+".");
@@ -101,6 +146,7 @@
* otherwise you create a new domain for that attribute
* Domain attributeSpec = dataSetSpec.getDomain(attr_name);
*/
+
Domain<?> fieldDomain;
if (!domainset.containsKey(field))
fieldDomain = DomainFactory.createDomainFromClass(m.getReturnType(), field);
@@ -121,9 +167,10 @@
return newfs;
}
- /* TODO: iterator */
- public Collection<FactSet> getFactsets() {
- return factsets.values();
+ /* TODO: is there a better way of doing this iterator? */
+ public Iterator<FactSet> getFactsets() {
+ return factsets.values().iterator();
+ //return factsets.values();
}
public Domain<?> getDomain(String field) {
More information about the jboss-svn-commits
mailing list