[jboss-svn-commits] JBL Code SVN: r19409 - in labs/jbossrules/contrib/machinelearning/decisiontree/src: dt/memory and 2 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Thu Apr 3 21:21:08 EDT 2008
Author: gizil
Date: 2008-04-03 21:21:08 -0400 (Thu, 03 Apr 2008)
New Revision: 19409
Added:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Discretizer.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java
Modified:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java
Log:
discretizing continuous attributes with any branching factor (recursively)
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -145,7 +145,45 @@
return dt;
}
+
+ public DecisionTree build_test(WorkingMemory wm, String klass,
+ String targetField, List<String> workingAttributes) {
+ unclassified_facts = new ArrayList<Fact>();
+ DecisionTree dt = new DecisionTree(klass);
+ // **OPT List<FactSet> facts = new ArrayList<FactSet>();
+ ArrayList<Fact> facts = new ArrayList<Fact>();
+ FactSet klass_fs = null;
+ Iterator<FactSet> it_fs = wm.getFactsets();
+ while (it_fs.hasNext()) {
+ FactSet fs = it_fs.next();
+ if (klass == fs.getClassName()) {
+ // **OPT facts.add(fs);
+ fs.assignTo(facts); // adding all facts of fs to "facts"
+ klass_fs = fs;
+ break;
+ }
+ }
+ dt.FACTS_READ += facts.size();
+ setNum_fact_processed(facts.size());
+
+ if (workingAttributes != null)
+ for (String attr : workingAttributes) {
+ //System.out.println("Bok degil " + attr);
+ dt.addDomain(klass_fs.getDomain(attr));
+ }
+ else
+ for (Domain<?> d : klass_fs.getDomains())
+ dt.addDomain(d);
+
+ dt.setTarget(targetField);
+
+ ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
+ Collections.sort(attrs);
+
+ return dt;
+ }
+
private TreeNode c45(DecisionTree dt, List<Fact> facts,
List<String> attributeNames) {
Added: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Discretizer.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Discretizer.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Discretizer.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -0,0 +1,254 @@
+package dt.builder;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+import dt.memory.Domain;
+import dt.memory.Fact;
+import dt.memory.FactAttrDistribution;
+import dt.memory.FactTargetDistribution;
+import dt.tools.Util;
+
+public class Discretizer {
+ private Domain<?> splitDomain;
+ private Domain<?> targetDomain;
+
+ private String splitAttr;
+ private List<Fact> facts;
+ private FactTargetDistribution distribution;
+
+ private List<SplitPoint> split_indices;
+ private int maxDepth = 1;
+ private Domain binaryDomain;
+
+ Discretizer(Domain<?> _targetDomain, List<Fact> _facts, FactTargetDistribution _facts_in_class) {
+ this.facts = new ArrayList<Fact>(_facts.size());
+ facts.addAll(_facts);
+ // sort them
+
+ this.targetDomain = _targetDomain;
+ this.distribution = _facts_in_class;
+
+ }
+ public Fact getSortedFact(int i) {
+ return facts.get(i);
+ }
+
+ public List<Fact> getSortedFacts() {
+ return facts;
+ }
+
+ public void init_binary_domain() {
+ Object key0 = Integer.valueOf(0);
+ Object key1 = Integer.valueOf(1);
+
+ binaryDomain = splitDomain.clone();
+ binaryDomain.addPseudoValue(key0);
+ binaryDomain.addPseudoValue(key1);
+ }
+
+ public List<Integer> findSplits(Domain attrDomain) {
+
+ this.splitAttr = attrDomain.getName();
+ this.splitDomain = attrDomain;
+ init_binary_domain();
+ Collections.sort(facts, facts.get(0).getDomain(attrDomain.getName()).factComparator());
+
+ split_indices = new ArrayList<SplitPoint>();
+ SplitPoint last_point = new SplitPoint(facts.size()-1, (Number) facts.get(facts.size()-1).getFieldValue(splitAttr));
+ split_indices.add(last_point);
+
+ find_a_split(0, facts.size(), getMaxDepth(), distribution, split_indices);
+ Collections.sort(split_indices, Discretizer.getSplitComparator());
+
+ List<Integer> splits = new ArrayList<Integer>(split_indices.size());
+ for (SplitPoint sp: split_indices) {
+ splits.add(Integer.valueOf(sp.getIndex()));
+ attrDomain.addIndex(sp.getIndex());
+ attrDomain.addPseudoValue(sp.getCut_point());
+
+ }
+
+ return splits;
+
+ }
+
+
+ private int getMaxDepth() {
+ return this.maxDepth ;
+ }
+
+
+ public SplitPoint find_a_split(int begin_index, int end_index, int depth,
+ FactTargetDistribution facts_in_class,
+ List<SplitPoint> split_points) {
+
+ if (facts.size() <= 1) {
+ System.out.println("fact.size <=1 returning 0.0....");
+ return null;
+ }
+ facts_in_class.evaluateMajority();
+ if (facts_in_class.getNum_supported_target_classes()==1) {
+ System.out.println("getNum_supported_target_classes=1 returning 0.0....");
+ return null; //?
+ }
+
+ if (depth == 0) {
+ System.out.println("depth == 0 returning 0.0....");
+ return null;
+ }
+
+ String targetAttr = targetDomain.getName();
+ List<?> targetValues = targetDomain.getValues();
+ if (Util.DEBUG) {
+ System.out.println("Discretizer.find_a_split() attributeToSplit? " + splitAttr);
+ for(int index =begin_index; index < end_index; index ++) {
+ Fact f= facts.get(index);
+ System.out.println("entropy.info_cont() SORTING: "+index+" attr "+splitAttr+ " "+ f );
+ }
+ }
+ /* initialize the distribution */
+ Object key0 = Integer.valueOf(0);
+ Object key1 = Integer.valueOf(1);
+// List<Object> keys = new ArrayList<Object>(2);
+// keys.add(key0);
+// keys.add(key1);
+
+
+
+ FactAttrDistribution facts_at_attribute = new FactAttrDistribution(binaryDomain, targetDomain);
+ facts_at_attribute.setTotal(facts.size());
+ facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
+ facts_at_attribute.setSumForAttr(key1, facts.size());
+
+ double best_sum = 100000.0;
+ SplitPoint bestPoint = split_points.get(split_points.size()-1);
+
+ int split_index =begin_index+1, index = begin_index+1;
+ FactAttrDistribution best_distribution = null;
+ //Iterator<Fact> f_ite = facts.iterator();
+ Iterator<Fact> f_ite = facts.listIterator(begin_index);
+
+ Fact f1 = f_ite.next();
+ Comparator<Fact> targetComp = f1.getDomain(targetAttr).factComparator();
+ if (Util.DEBUG) System.out.println("\nentropy.info_cont() SEARCHING: "+begin_index+" attr "+splitAttr+ " "+ f1 );
+ while (f_ite.hasNext() && index<end_index) {/* 2. Look for potential cut-points. */
+
+ Fact f2 = f_ite.next();
+ if (Util.DEBUG) System.out.println("entropy.info_cont() SEARCHING: "+(index)+" attr "+splitAttr+ " "+ f2 );
+
+ Object targetKey = f1.getFieldValue(targetAttr);
+
+ // System.out.println("My key: "+ targetKey.toString());
+ //for (Object attr_key : attr_values)
+
+ /* every time it changes the place in the distribution */
+ facts_at_attribute.change(key0, targetKey, +1);
+ facts_at_attribute.change(key1, targetKey, -1);
+
+ /*
+ * 2.1 Cut points are points in the sorted list above where the class labels change.
+ * Eg. if I had five instances with values for the attribute of interest and labels
+ * (1.0,A), (1.4,A), (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only
+ * two cutpoints of interest: 1.85 and 5 (mid-way between the points
+ * where the classes change from A to B or vice versa).
+ */
+
+ if ( targetComp.compare(f1, f2)!=0) {
+ // the cut point
+ Number cp_i = (Number) f1.getFieldValue(splitAttr);
+ Number cp_i_next = (Number) f2.getFieldValue(splitAttr);
+
+ Number cut_point = (Double)(cp_i.doubleValue() + cp_i_next.doubleValue()) / 2;
+
+ /*
+ * 3. Evaluate your favourite disparity measure
+ * (info gain, gain ratio, gini coefficient, chi-squared test) on the cut point
+ * and calculate its gain
+ */
+ double sum = Entropy.calc_info_attr(facts_at_attribute);
+ //System.out.println("**entropy.info_contattr() FOUND: "+ sum + " best sum "+best_sum +
+ if (Util.DEBUG) System.out.println(" **Try @"+ (index)+" sum "+sum+" best sum "+best_sum +
+ " value ("+ f1.getFieldValue(splitAttr) +"-|"+ cut_point+"|-"+ f2.getFieldValue(splitAttr)+")");
+
+ if (sum < best_sum) {
+ best_sum = sum;
+
+ if (Util.DEBUG) System.out.println(Util.ntimes("?", 10)+"** FOUND: @"+(index)+" target ("+ f1.getFieldValue(targetAttr) +"-|T|-"+ f2.getFieldValue(targetAttr)+")");
+ split_index = index;
+
+ bestPoint = new SplitPoint(index-1, cut_point);
+ bestPoint.setValue(best_sum);
+
+ if (best_distribution != null) best_distribution.clear();
+ best_distribution = new FactAttrDistribution(facts_at_attribute);
+ }
+ } else {}
+ f1 = f2;
+ index++;
+ }
+ split_points.add(bestPoint);
+
+ double sum1 = 0.0;
+
+ find_a_split(begin_index, split_index, depth-1,
+ best_distribution.getAttrFor(key0), split_points);
+
+ double sum2 = 0.0;
+ find_a_split(split_index, end_index, depth-1,
+ best_distribution.getAttrFor(key1), split_points);
+
+
+ if (Util.DEBUG) {
+ System.out.println("entropy.info_contattr(BOK_last) split_indices.size "+split_indices.size());
+ for(SplitPoint i : split_points)
+ System.out.println("all split_indices "+i.getIndex() + " the value "+ i.getCut_point());
+ }
+ return bestPoint;
+ }
+
+ public static Comparator<SplitPoint> getSplitComparator() {
+ return new SplitComparator();
+ }
+
+ private static class SplitComparator implements Comparator<SplitPoint>{
+ public int compare(SplitPoint sp1, SplitPoint sp2) {
+ return sp1.getIndex() - sp2.getIndex();
+
+ }
+ }
+
+ public class SplitPoint {
+
+ private int index;
+ private Number cut_point;
+ private double value;
+
+ SplitPoint(int _index, Number _point) {
+ this.index = _index;
+ this.cut_point = _point;
+ }
+ public void setValue(double info_sum) {
+ this.value = info_sum;
+ }
+ public int getIndex() {
+ return index;
+ }
+ public void setIndex(int index) {
+ this.index = index;
+ }
+ public Number getCut_point() {
+ return cut_point;
+ }
+ public void setCut_point(Number cut_point) {
+ this.cut_point = cut_point;
+ }
+ public double getValue() {
+ return value;
+ }
+
+ }
+}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -17,6 +17,7 @@
public class Entropy implements InformationMeasure {
+
public static Domain<?> chooseContAttribute(DecisionTree dt, List<Fact> facts,
FactDistribution facts_in_class, List<String> attrs) {
@@ -28,158 +29,75 @@
List<Integer> split_indices = null;
Domain<?> targetDomain = dt.getDomain(dt.getTarget());
for (String attr : attrs) {
+// if (attr.equalsIgnoreCase(targetDomain.getName()))
+// continue;
System.out.println("Which attribute to try: "+ attr);
double gain = 0;
if (dt.getDomain(attr).isDiscrete()) {
+ /* */
attrDomain = dt.getDomain(attr).clone();
for (Object v: dt.getDomain(attr).getValues())
attrDomain.addValue(v);
- gain = dt_info - info_attr(facts, attrDomain, targetDomain);
+ /* */
+ //attrDomain = dt.getDomain(attr);
+ gain = dt_info - info_attr(facts, dt.getDomain(attr), targetDomain);
} else {
/* 1. sort the values */
- Collections.sort(facts, facts.get(0).getDomain(attr).factComparator());
- List<Fact> splits = getSplitPoints(facts, dt.getTarget());
-
+ Discretizer visitor = new Discretizer(targetDomain, facts, facts_in_class);
attrDomain = dt.getDomain(attr).clone();
- attrDomain.addPseudoValue(facts.get(facts.size()-1).getFieldValue(attr));
-// System.out.println("entropy.chooseContAttribute(1)*********** num of split for "+
-// attr+": "+ attrDomain.getValues().size()+ " ("+ attrDomain.getValues().get(0)+")");
- split_indices = new ArrayList<Integer>();
- //System.out.println("entropy.chooseContAttribute(BOK) size "+split_indices.size());
- gain = dt_info - info_contattr(facts, attrDomain, targetDomain,
- facts_in_class, split_indices, splits);
-// System.out.println("entropy.chooseContAttribute(2)*********** num of split for "+
-// attr+": "+ attrDomain.getValues().size());
+ split_indices = visitor.findSplits(attrDomain);
+ int index = 0;
+ for (Integer i: split_indices) {
+ System.out.print("Split indices:"+ i);
+ System.out.print(" domain "+attrDomain.getValues().get(index));
+ System.out.print(","+attrDomain.getIndices().get(index));
+ System.out.println(" fact "+visitor.getSortedFact(i));
+ index++;
+ }
+ gain = dt_info - calc_info_contattr(visitor.getSortedFacts(), attrDomain, targetDomain, split_indices);
+
}
+ System.out.println(Util.ntimes("\n",3)+Util.ntimes("?",10)+" ATTR TRIAL "+attr + " the gain "+gain + " info "+ dt_info );
+
if (gain > greatestGain) {
greatestGain = gain;
attributeWithGreatestGain = attr;
bestDomain = attrDomain;
- if (!bestDomain.isDiscrete())
- bestDomain.setIndices(split_indices);
+// if (!bestDomain.isDiscrete())
+// bestDomain.setIndices(split_indices);
+
+ System.out.println(Util.ntimes("\n",3)+Util.ntimes("!",10)+" NEW BEST "+attributeWithGreatestGain + " the gain "+greatestGain );
}
}
return bestDomain;
}
- public static double info_contattr(List<Fact> facts,
+
+
+ public static double info_contattr_rec(List<Fact> facts, int begin_index, int end_index, int depth,
Domain splitDomain, Domain<?> targetDomain,
FactTargetDistribution facts_in_class,
List<Integer> split_indices,
List<Fact> split_facts) {
-
- String splitAttr = splitDomain.getName();
- List<?> splitValues = splitDomain.getValues();
- String targetAttr = targetDomain.getName();
- List<?> targetValues = targetDomain.getValues();
- if (Util.DEBUG) {
- System.out.println("entropy.info_cont() attributeToSplit? " + splitAttr);
- int f_i=0;
- for(Fact f: facts) {
- System.out.println("entropy.info_cont() SORTING: "+f_i+" attr "+splitAttr+ " "+ f );
- f_i++;
- }
+
+ if (facts.size() <= 1) {
+ System.out.println("fact.size <=1 returning 0.0....");
+ return 0.0;
}
+ facts_in_class.evaluateMajority();
+ if (facts_in_class.getNum_supported_target_classes()==1) {
+ System.out.println("getNum_supported_target_classes=1 returning 0.0....");
+ return 0.0; //?
+ }
- if (facts.size() <= 1) {
- System.out
- .println("The size of the fact list is 0 oups??? exiting....");
- System.exit(0);
+ if (depth >= 2) {
+ System.out.println("split_indices ==4 returning 0.0....");
+ return 0.0;
}
- if (split_facts.size() < 1) {
- System.out
- .println("The size of the splits is 0 oups??? exiting....");
- System.exit(0);
- }
- /* initialize the distribution */
- Object key0 = Integer.valueOf(0);
- Object key1 = Integer.valueOf(1);
- List<Object> keys = new ArrayList<Object>(2);
- keys.add(key0);
- keys.add(key1);
-
- FactAttrDistribution facts_at_attribute = new FactAttrDistribution(keys, targetDomain);
- facts_at_attribute.setTotal(facts.size());
- facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
- facts_at_attribute.setSumForAttr(key1, facts.size());
-
- double best_sum = -100000.0;
- Object value_to_split = splitValues.get(0);
- int split_index =1, index = 1;
- Iterator<Fact> f_ite = facts.iterator();
- Fact f1 = f_ite.next();
- Comparator<Fact> targetComp = f1.getDomain(targetAttr).factComparator();
- if (Util.DEBUG) System.out.println("\nentropy.info_cont() SEARCHING: "+split_index+" attr "+splitAttr+ " "+ f1 );
- while (f_ite.hasNext()) {/* 2. Look for potential cut-points. */
-
- Fact f2 = f_ite.next();
- if (Util.DEBUG) System.out.print("entropy.info_cont() SEARCHING: "+(index+1)+" attr "+splitAttr+ " "+ f2 );
- Object targetKey = f2.getFieldValue(targetAttr);
-
- // System.out.println("My key: "+ targetKey.toString());
- //for (Object attr_key : attr_values)
-
- /* every time it change the place in the distribution */
- facts_at_attribute.change(key0, targetKey, +1);
- facts_at_attribute.change(key1, targetKey, -1);
-
- /*
- * 2.1 Cut points are points in the sorted list above where the class labels change.
- * Eg. if I had five instances with values for the attribute of interest and labels
- * (1.0,A), (1.4,A), (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only
- * two cutpoints of interest: 1.85 and 5 (mid-way between the points
- * where the classes change from A to B or vice versa).
- */
-
- if ( targetComp.compare(f1, f2)!=0) {
- // the cut point
- Number cp_i = (Number) f1.getFieldValue(splitAttr);
- Number cp_i_next = (Number) f2.getFieldValue(splitAttr);
-
- Number cut_point = (Double)(cp_i.doubleValue() + cp_i_next.doubleValue()) / 2;
-
- /*
- * 3. Evaluate your favourite disparity measure
- * (info gain, gain ratio, gini coefficient, chi-squared test) on the cut point
- * and calculate its gain
- */
- double sum = calc_info_attr(facts_at_attribute);
- //System.out.println("**entropy.info_contattr() FOUND: "+ sum + " best sum "+best_sum +
- if (Util.DEBUG) System.out.println(" **Try "+ sum + " best sum "+best_sum +
- " value ("+ f1.getFieldValue(splitAttr) +"-|"+ value_to_split+"|-"+ f2.getFieldValue(splitAttr)+")");
-
- if (sum > best_sum) {
- best_sum = sum;
- value_to_split = cut_point;
- if (Util.DEBUG) System.out.println(Util.ntimes("?", 10)+"** FOUND: target ("+ f1.getFieldValue(targetAttr) +"-|T|-"+ f2.getFieldValue(targetAttr)+")");
- split_index = index;
- }
- } else {}
- f1 = f2;
- index++;
- }
- splitDomain.addPseudoValue(value_to_split);
- Util.insert(split_indices, Integer.valueOf(split_index));
- if (Util.DEBUG) {
- System.out.println("entropy.info_contattr(BOK_last) split_indices.size "+split_indices.size());
- for(Integer i : split_indices)
- System.out.println("entropy.info_contattr(FOUNDS) split_indices "+i + " the fact "+facts.get(i));
- System.out.println("entropy.chooseContAttribute(1.5)*********** num of split for "+
- splitAttr+": "+ splitDomain.getValues().size());
- }
- return best_sum;
- }
-
- public static double info_contattr_rec(List<Fact> facts,
- Domain splitDomain, Domain<?> targetDomain,
- FactTargetDistribution facts_in_class,
- List<Integer> split_indices,
- List<Fact> split_facts) {
-
String splitAttr = splitDomain.getName();
List<?> splitValues = splitDomain.getValues();
String targetAttr = targetDomain.getName();
@@ -192,18 +110,6 @@
f_i++;
}
}
-
- if (facts.size() <= 1) {
- System.out
- .println("The size of the fact list is 0 oups??? exiting....");
- System.exit(0);
- }
- if (split_facts.size() < 1) {
- System.out
- .println("The size of the splits is 0 oups??? exiting....");
- System.exit(0);
- }
-
/* initialize the distribution */
Object key0 = Integer.valueOf(0);
Object key1 = Integer.valueOf(1);
@@ -211,26 +117,32 @@
keys.add(key0);
keys.add(key1);
+ Domain trialDomain = splitDomain.clone();
+ trialDomain.addPseudoValue(key0);
+ trialDomain.addPseudoValue(key1);
- FactAttrDistribution facts_at_attribute = new FactAttrDistribution(keys, targetDomain);
+ FactAttrDistribution facts_at_attribute = new FactAttrDistribution(trialDomain, targetDomain);
facts_at_attribute.setTotal(facts.size());
facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
facts_at_attribute.setSumForAttr(key1, facts.size());
- double best_sum = -100000.0;
+ double best_sum = 100000.0;
Object value_to_split = splitValues.get(0);
int split_index =1, index = 1;
FactAttrDistribution best_distribution = null;
- Iterator<Fact> f_ite = facts.iterator();
+ //Iterator<Fact> f_ite = facts.iterator();
+ Iterator<Fact> f_ite = facts.listIterator(begin_index);
+
Fact f1 = f_ite.next();
Comparator<Fact> targetComp = f1.getDomain(targetAttr).factComparator();
if (Util.DEBUG) System.out.println("\nentropy.info_cont() SEARCHING: "+split_index+" attr "+splitAttr+ " "+ f1 );
- while (f_ite.hasNext()) {/* 2. Look for potential cut-points. */
+ while (f_ite.hasNext() && index<=end_index) {/* 2. Look for potential cut-points. */
Fact f2 = f_ite.next();
- if (Util.DEBUG) System.out.print("entropy.info_cont() SEARCHING: "+(index+1)+" attr "+splitAttr+ " "+ f2 );
- Object targetKey = f2.getFieldValue(targetAttr);
+ if (Util.DEBUG) System.out.println("entropy.info_cont() SEARCHING: "+(index+1)+" attr "+splitAttr+ " "+ f2 );
+ Object targetKey = f1.getFieldValue(targetAttr);
+
// System.out.println("My key: "+ targetKey.toString());
//for (Object attr_key : attr_values)
@@ -260,15 +172,16 @@
*/
double sum = calc_info_attr(facts_at_attribute);
//System.out.println("**entropy.info_contattr() FOUND: "+ sum + " best sum "+best_sum +
- if (Util.DEBUG) System.out.println(" **Try "+ sum + " best sum "+best_sum +
+ if (Util.DEBUG) System.out.println(" **Try @"+ (index+1)+" sum "+sum+" best sum "+best_sum +
" value ("+ f1.getFieldValue(splitAttr) +"-|"+ value_to_split+"|-"+ f2.getFieldValue(splitAttr)+")");
- if (sum > best_sum) {
+ if (sum < best_sum) {
best_sum = sum;
value_to_split = cut_point;
- if (Util.DEBUG) System.out.println(Util.ntimes("?", 10)+"** FOUND: target ("+ f1.getFieldValue(targetAttr) +"-|T|-"+ f2.getFieldValue(targetAttr)+")");
+ if (Util.DEBUG) System.out.println(Util.ntimes("?", 10)+"** FOUND: @"+(index+1)+" target ("+ f1.getFieldValue(targetAttr) +"-|T|-"+ f2.getFieldValue(targetAttr)+")");
split_index = index;
- best_distribution = facts_at_attribute.clone();
+ if (best_distribution != null) best_distribution.clear();
+ best_distribution = new FactAttrDistribution(facts_at_attribute);
}
} else {}
f1 = f2;
@@ -283,11 +196,26 @@
List<Integer> split_indices,
List<Fact> split_facts)
*/
- info_contattr_rec(facts.subList(0, split_index),
+ //FactTargetDistribution split0 = best_distribution.getAttrFor(key0);
+ //split0.
+ double sum1 = 0.0;
+ if (split_index>1) {
+ System.out.println("Best distribution: "+ best_distribution);
+ sum1 = info_contattr_rec(facts, begin_index, split_index, depth+1,
splitDomain, targetDomain,
best_distribution.getAttrFor(key0),
split_indices,
split_facts);
+ }
+ //FactTargetDistribution split1 = best_distribution.getAttrFor(key1);
+ double sum2 = 0.0;
+ //if ((facts.size()-split_index)>1 && split1.getNum_supported_target_classes()>1) {
+ sum2 = info_contattr_rec(facts, split_index, end_index, depth+1,
+ splitDomain, targetDomain,
+ best_distribution.getAttrFor(key1),
+ split_indices,
+ split_facts);
+ //}
if (Util.DEBUG) {
@@ -297,7 +225,7 @@
System.out.println("entropy.chooseContAttribute(1.5)*********** num of split for "+
splitAttr+": "+ splitDomain.getValues().size());
}
- return best_sum;
+ return (sum1+ sum2);
}
/*
@@ -357,16 +285,17 @@
return attributeWithGreatestGain;
}
+
public static double info_attr(List<Fact> facts, Domain<?> splitDomain, Domain<?> targetDomain) {
String attributeToSplit = splitDomain.getName();
- List<?> attributeValues = splitDomain.getValues();
+ //List<?> attributeValues = splitDomain.getValues();
String target = targetDomain.getName();
//List<?> targetValues = targetDomain.getValues();
if (Util.DEBUG) System.out.println("What is the attributeToSplit? " + attributeToSplit);
/* initialize the hashtable */
- FactAttrDistribution facts_at_attribute = new FactAttrDistribution(attributeValues, targetDomain);
+ FactAttrDistribution facts_at_attribute = new FactAttrDistribution(splitDomain, targetDomain);
facts_at_attribute.setTotal(facts.size());
for (Fact f : facts) {
@@ -385,21 +314,62 @@
return sum;
}
+ public static double calc_info_contattr(List<Fact> facts,
+ Domain<?> splitDomain, Domain<?> targetDomain,
+ List<Integer> split_indices) {
+
+ List<?> splitValues = splitDomain.getValues();
+ String targetAttr = targetDomain.getName();
+
+ System.out.println("Numof classes in domain "+ splitDomain.getValues().size());
+ System.out.println("Numof splits in domain "+ splitDomain.getIndices().size());
+
+ System.out.println("Numof splits in indices "+ split_indices.size());
+ FactAttrDistribution facts_at_attribute = new FactAttrDistribution(splitDomain, targetDomain);
+ facts_at_attribute.setTotal(facts.size());
+
+ int index = 0;
+ int split_index = 0;
+ Object attr_key = splitValues.get(split_index);
+ for (Fact f : facts) {
+
+ if (index == split_indices.get(split_index).intValue()+1 ) {
+ attr_key = splitValues.get(split_index+1);
+ split_index++;
+ }
+ Object targetKey = f.getFieldValue(targetAttr);
+ facts_at_attribute.change(attr_key, targetKey, +1);
+
+
+ index++;
+ }
+ double sum = calc_info_attr(facts_at_attribute);
+ return sum;
+
+ }
+
+
/*
* for both
*/
- private static double calc_info_attr( FactAttrDistribution facts_of_attribute) {
+ public static double calc_info_attr( FactAttrDistribution facts_of_attribute) {
Collection<Object> attributeValues = facts_of_attribute.getAttributes();
int fact_size = facts_of_attribute.getTotal();
double sum = 0.0;
+ if (fact_size>0)
for (Object attr : attributeValues) {
int total_num_attr = facts_of_attribute.getSumForAttr(attr);
//double sum_attr = 0.0;
if (total_num_attr > 0) {
- sum += ((double) total_num_attr / (double) fact_size) *
- calc_info(facts_of_attribute.getAttrFor(attr));
+ double prob = (double) total_num_attr / (double) fact_size;
+ System.out.print("{("+total_num_attr +"/"+fact_size +":"+prob +")* [");
+ double info = calc_info(facts_of_attribute.getAttrFor(attr));
+
+ sum += prob * info;
+ System.out.print("]} ");
}
}
+ System.out.println("\n == "+sum);
return sum;
}
@@ -416,6 +386,7 @@
int total_num_facts = facts_in_class.getSum();
Collection<?> targetValues = facts_in_class.getTargetClasses();
double prob, sum = 0;
+ String out =" ";
for (Object key : targetValues) {
int num_in_class = facts_in_class.getVoteFor(key);
// System.out.println("num_in_class : "+ num_in_class + " key "+ key+ " and the total num "+ total_num_facts);
@@ -423,14 +394,172 @@
if (num_in_class > 0) {
prob = (double) num_in_class / (double) total_num_facts;
/* TODO what if it is a sooo small number ???? */
- sum += -1 * prob * Util.log2(prob);
+ out += "("+num_in_class+ "/"+total_num_facts+":"+prob+")" +"*"+ Util.log2(prob) + " + ";
+ sum -= prob * Util.log2(prob);
// System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"where the sum: "+sum);
}
}
+ System.out.print(out +"= " +sum);
return sum;
}
- private static List<Fact> getSplitPoints(List<Fact> facts, String target) {
+ public static Domain<?> chooseContAttribute_binary(DecisionTree dt, List<Fact> facts,
+ FactDistribution facts_in_class, List<String> attrs) {
+
+ double dt_info = calc_info(facts_in_class);
+ double greatestGain = -100000.0;
+ String attributeWithGreatestGain = attrs.get(0);
+ Domain attrDomain = dt.getDomain(attributeWithGreatestGain);
+ Domain bestDomain = null;
+ List<Integer> split_indices = null;
+ Domain<?> targetDomain = dt.getDomain(dt.getTarget());
+ for (String attr : attrs) {
+ System.out.println("Which attribute to try: "+ attr);
+ double gain = 0;
+ if (dt.getDomain(attr).isDiscrete()) {
+ attrDomain = dt.getDomain(attr).clone();
+ for (Object v: dt.getDomain(attr).getValues())
+ attrDomain.addValue(v);
+ gain = dt_info - info_attr(facts, attrDomain, targetDomain);
+
+ } else {
+ /* 1. sort the values */
+ Collections.sort(facts, facts.get(0).getDomain(attr).factComparator());
+ List<Fact> splits = getSplitPoints(facts, dt.getTarget());
+
+ attrDomain = dt.getDomain(attr).clone();
+ attrDomain.addPseudoValue(facts.get(facts.size()-1).getFieldValue(attr));
+ split_indices = new ArrayList<Integer>();
+ gain = dt_info - info_contattr_binary(facts, attrDomain, targetDomain,
+ facts_in_class, split_indices, splits);
+
+ }
+
+ if (gain > greatestGain) {
+ greatestGain = gain;
+ attributeWithGreatestGain = attr;
+ bestDomain = attrDomain;
+ if (!bestDomain.isDiscrete())
+ bestDomain.setIndices(split_indices);
+ }
+ }
+
+ return bestDomain;
+ }
+
+ public static double info_contattr_binary(List<Fact> facts,
+ Domain splitDomain, Domain<?> targetDomain,
+ FactTargetDistribution facts_in_class,
+ List<Integer> split_indices,
+ List<Fact> split_facts) {
+
+ String splitAttr = splitDomain.getName();
+ List<?> splitValues = splitDomain.getValues();
+ String targetAttr = targetDomain.getName();
+ List<?> targetValues = targetDomain.getValues();
+ if (Util.DEBUG) {
+ System.out.println("entropy.info_cont() attributeToSplit? " + splitAttr);
+ int f_i=0;
+ for(Fact f: facts) {
+ System.out.println("entropy.info_cont() SORTING: "+f_i+" attr "+splitAttr+ " "+ f );
+ f_i++;
+ }
+ }
+
+ if (facts.size() <= 1) {
+ System.out
+ .println("The size of the fact list is 0 oups??? exiting....");
+ System.exit(0);
+ }
+ if (split_facts.size() < 1) {
+ System.out
+ .println("The size of the splits is 0 oups??? exiting....");
+ System.exit(0);
+ }
+
+ /* initialize the distribution */
+ Object key0 = Integer.valueOf(0);
+ Object key1 = Integer.valueOf(1);
+ List<Object> keys = new ArrayList<Object>(2);
+ keys.add(key0);
+ keys.add(key1);
+ Domain trialDomain = splitDomain.clone();
+ trialDomain.addPseudoValue(key0);
+ trialDomain.addPseudoValue(key1);
+
+ FactAttrDistribution facts_at_attribute = new FactAttrDistribution(trialDomain, targetDomain);
+ facts_at_attribute.setTotal(facts.size());
+ facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
+ facts_at_attribute.setSumForAttr(key1, facts.size());
+
+ double best_sum = 100000.0;
+ Object value_to_split = splitValues.get(0);
+ int split_index =1, index = 1;
+ Iterator<Fact> f_ite = facts.iterator();
+ Fact f1 = f_ite.next();
+ Comparator<Fact> targetComp = f1.getDomain(targetAttr).factComparator();
+ if (Util.DEBUG) System.out.println("\nentropy.info_cont() SEARCHING: "+split_index+" attr "+splitAttr+ " "+ f1 );
+ while (f_ite.hasNext()) {/* 2. Look for potential cut-points. */
+
+ Fact f2 = f_ite.next();
+ Object targetKey = f1.getFieldValue(targetAttr);
+ if (Util.DEBUG) System.out.println("entropy.info_cont() SEARCHING: "+(index+1)+" "+ f2 );
+
+ // System.out.println("My key: "+ targetKey.toString());
+ //for (Object attr_key : attr_values)
+
+ /* every time it changes the place in the distribution */
+ facts_at_attribute.change(key0, targetKey, +1);
+ facts_at_attribute.change(key1, targetKey, -1);
+
+ /*
+ * 2.1 Cut points are points in the sorted list above where the class labels change.
+ * Eg. if I had five instances with values for the attribute of interest and labels
+ * (1.0,A), (1.4,A), (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only
+ * two cutpoints of interest: 1.85 and 5 (mid-way between the points
+ * where the classes change from A to B or vice versa).
+ */
+
+ if ( targetComp.compare(f1, f2)!=0) {
+ // the cut point
+ Number cp_i = (Number) f1.getFieldValue(splitAttr);
+ Number cp_i_next = (Number) f2.getFieldValue(splitAttr);
+
+ Number cut_point = (Double)(cp_i.doubleValue() + cp_i_next.doubleValue()) / 2;
+
+ /*
+ * 3. Evaluate your favourite disparity measure
+ * (info gain, gain ratio, gini coefficient, chi-squared test) on the cut point
+ * and calculate its gain
+ */
+ double sum = calc_info_attr(facts_at_attribute);
+ //System.out.println("**entropy.info_contattr() FOUND: "+ sum + " best sum "+best_sum +
+ if (Util.DEBUG) System.out.println(" **Try @"+ (index+1)+" sum:"+ sum + " best sum"+best_sum +
+ " value ("+ f1.getFieldValue(splitAttr) +"-|"+ value_to_split+"|-"+ f2.getFieldValue(splitAttr)+")");
+
+ if (sum < best_sum) {
+ best_sum = sum;
+ value_to_split = cut_point;
+ if (Util.DEBUG) System.out.println(Util.ntimes("?", 10)+"** FOUND: @"+(index+1)+" target ("+ f1.getFieldValue(targetAttr) +"-|T|-"+ f2.getFieldValue(targetAttr)+")");
+ split_index = index;
+ }
+ } else {}
+ f1 = f2;
+ index++;
+ }
+ splitDomain.addPseudoValue(value_to_split);
+ Util.insert(split_indices, Integer.valueOf(split_index));
+ if (Util.DEBUG) {
+ System.out.println("entropy.info_contattr(BOK_last) split_indices.size "+split_indices.size());
+ for(Integer i : split_indices)
+ System.out.println("entropy.info_contattr(FOUNDS) split_indices "+i + " the fact "+facts.get(i));
+ System.out.println("entropy.chooseContAttribute(1.5)*********** num of split for "+
+ splitAttr+": "+ splitDomain.getValues().size());
+ }
+ return best_sum;
+ }
+
+ public static List<Fact> getSplitPoints(List<Fact> facts, String target) {
List<Fact> splits = new ArrayList<Fact>();
Iterator<Fact> it_f = facts.iterator();
Fact f1 = it_f.next();
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -132,5 +132,9 @@
public List<Integer> getIndices() {
return null; //indices;
}
+
+ public void addIndex(int index) {
+ // Intentional no-op: BooleanDomain does not track split indices (getIndices() returns null).
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -32,7 +32,9 @@
public Domain<T> clone();
void setIndices(List<Integer> split_indices);
+
List<Integer> getIndices();
+ void addIndex(int index);
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactAttrDistribution.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -7,29 +7,50 @@
import dt.tools.Util;
public class FactAttrDistribution {
- String attr_sum = Util.sum();
- Hashtable<Object, FactTargetDistribution> facts_at_attr;
+ private String attr_sum = Util.sum();
+ private Domain<?> domain; // domain of the attr
+ private Hashtable<Object, FactTargetDistribution> facts_at_attr;
private int total_num;
- public FactAttrDistribution(List<?> attributeValues, Domain<?> targetDomain) {
+ public FactAttrDistribution(Domain<?> attributeDomain, Domain<?> targetDomain) {
+ this.domain = attributeDomain;
+ List<?> attributeValues = this.domain.getValues();
facts_at_attr = new Hashtable<Object, FactTargetDistribution>(attributeValues.size());
for (Object attr : attributeValues) {
facts_at_attr.put(attr, new FactTargetDistribution(targetDomain));
-// for (Object t : targetDomain.getValues()) {
-// facts_at_attr.get(attr).put(t, 0);
-// }
-// facts_at_attr.get(attr).put(attr_sum, 0);
}
}
- public FactAttrDistribution clone() {
- return this.clone();
+ public FactAttrDistribution(FactAttrDistribution copy) {
+ this.domain = copy.getDomain();
+ this.facts_at_attr = new Hashtable<Object, FactTargetDistribution>(copy.getNumAttributes());
+
+ for (Object attr : copy.getAttributes()) {
+ FactTargetDistribution attr_x = new FactTargetDistribution(copy.getAttrFor(attr));
+ facts_at_attr.put(attr, attr_x);
+ }
+ this.total_num = copy.getTotal();
+
}
+// public FactAttrDistribution clone() {
+// FactAttrDistribution temp = new FactAttrDistribution(this);
+// return this.clone();
+// }
+
+ private Domain<?> getDomain() {
+ return this.domain;
+ }
+
+ public void clear() {
+ this.facts_at_attr.clear();
+ // TODO(review): clearing the table drops the FactTargetDistribution entries;
+ // confirm whether each entry should also be reset individually before removal.
+ }
+
public void setTotal(int size) {
this.total_num = size;
}
@@ -49,15 +70,11 @@
facts_at_attr.get(attr_value).setSum(total);
}
-// public void setTargetDistForAttr(Object attr_value, Hashtable<Object, Integer> targetDist) {
-// for (Object target: targetDist.keySet())
-// facts_at_attr.get(attr_value).put(target,targetDist.get(target));
-// }
public void setTargetDistForAttr(Object attr_value, FactTargetDistribution targetDist) {
//facts_at_attr.put(attr_value, targetDist);
- /* TODO should i make a close */
+ /* TODO should i make a clone */
FactTargetDistribution old = facts_at_attr.get(attr_value);
old.setDistribution(targetDist);
@@ -67,20 +84,25 @@
facts_at_attr.get(attrValue).change(targetValue, i);
facts_at_attr.get(attrValue).change(attr_sum, i);
-/* int num_1 = facts_at_attr.get(attrValue).get(targetValue).intValue();
- num_1 += i;
- facts_at_attr.get(attrValue).put(targetValue, num_1);
- int total_num_1 = facts_at_attr.get(attrValue).get(attr_sum).intValue();
- total_num_1 += i;
- facts_at_attr.get(attrValue).put(attr_sum, total_num_1);
- */
-
}
public Collection<Object> getAttributes() {
return facts_at_attr.keySet();
}
+ public int getNumAttributes() {
+ return facts_at_attr.keySet().size();
+ }
+
+ public String toString() {
+ String out = "FAD: attr: "+domain.getName()+" total num: "+ this.getTotal() + "\n" ;
+ for (Object attr : this.getAttributes()) {
+ FactTargetDistribution ftd = facts_at_attr.get(attr);
+ out += ftd ;
+ }
+ return out;
+ }
+
}
Added: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -0,0 +1,78 @@
+package dt.memory;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Hashtable;
+import java.util.List;
+
+import dt.tools.Util;
+
+/* Extends FactTargetDistribution: in addition to the per-target-value
+ * vote counts, it also keeps the facts themselves that fall under each
+ * target value.
+ */
+public class FactDistribution extends FactTargetDistribution{
+
+ private String attr_sum = Util.sum();
+ private Hashtable<Object, List<Fact>> facts_at_target;
+
+ /*
+ private int num_supported_target_classes;
+ private Object the_winner_target_class;
+ */
+ public FactDistribution(Domain<?> targetDomain) {
+
+ super(targetDomain);
+
+ //num_supported_target_classes = 0;
+ List<?> targetValues = targetDomain.getValues();
+
+ facts_at_target = new Hashtable<Object, List<Fact>>(targetValues.size());
+ for (Object t : targetValues) {
+
+ facts_at_target.put(t, new ArrayList<Fact>());
+ }
+
+ }
+
+ public void calculateDistribution(List<Fact> facts){
+ int total_num_facts = 0;
+ String target = super.getTargetDomain().getName();
+ for (Fact f : facts) {
+ total_num_facts++;
+ Object target_key = f.getFieldValue(target);
+ // System.out.println("My key: "+ key.toString());
+ super.change(target_key, 1); // add one for vote for the target value : target_key
+ facts_at_target.get(target_key).add(f);
+
+ }
+ super.change(attr_sum, total_num_facts);
+
+ }
+
+ public Collection<Object> getTargetClasses() {
+ return facts_at_target.keySet();
+ }
+
+ public List<Fact> getSupportersFor(Object value) {
+ return facts_at_target.get(value);
+ }
+
+/*
+ public int getNum_supported_target_classes() {
+ return num_supported_target_classes;
+ }
+
+ public void setNum_supperted_target_classes(int num_supperted_target_classes) {
+ this.num_supported_target_classes = num_supperted_target_classes;
+ }
+
+ public Object getThe_winner_target_class() {
+ return the_winner_target_class;
+ }
+
+ public void setThe_winner_target_class(Object the_winner_target_class) {
+ this.the_winner_target_class = the_winner_target_class;
+ }
+ */
+
+}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactTargetDistribution.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -13,6 +13,11 @@
protected Domain<?> targetDomain;
private Hashtable<Object, Integer> num_at_target;
+ //
+ private int num_supported_target_classes;
+ private Object the_winner_target_class;
+ //
+
public FactTargetDistribution(Domain<?> targetDomain) {
this.targetDomain = targetDomain;
@@ -22,9 +27,20 @@
num_at_target.put(t, 0);
}
num_at_target.put(attr_sum, 0);
+ num_supported_target_classes = 0;
}
+ public FactTargetDistribution(FactTargetDistribution copy_dist) {
+
+ this.targetDomain = copy_dist.getTargetDomain();
+ List<?> targetValues = targetDomain.getValues();
+ this.num_at_target = new Hashtable<Object, Integer>(targetValues.size() + 1);
+ this.setDistribution(copy_dist);
+ this.num_supported_target_classes = copy_dist.getNum_supported_target_classes();
+
+ }
+
public Collection<?> getTargetClasses() {
return targetDomain.getValues();
}
@@ -47,18 +63,64 @@
public void setTargetDomain(Domain<?> targetDomain) {
this.targetDomain = targetDomain.clone();
}
+ public int getNum_supported_target_classes() {
+ return num_supported_target_classes;
+ }
+
+ public void setNum_supperted_target_classes(int num_supperted_target_classes) {
+ this.num_supported_target_classes = num_supperted_target_classes;
+ }
+
+ public Object getThe_winner_target_class() {
+ return the_winner_target_class;
+ }
+
+ public void setThe_winner_target_class(Object the_winner_target_class) {
+ this.the_winner_target_class = the_winner_target_class;
+ }
public void change(Object targetValue, int i) {
int num_1 = num_at_target.get(targetValue).intValue();
num_1 += i;
num_at_target.put(targetValue, num_1);
}
+
+ public void evaluateMajority() {
+
+ List<?> targetValues = targetDomain.getValues();
+ int winner_vote = 0;
+ int num_supporters = 0;
+
+ Object winner = null;
+ for (Object key : targetValues) {
+ int num_in_class = getVoteFor(key);
+ if (num_in_class > 0)
+ num_supporters++;
+ if (num_in_class > winner_vote) {
+ winner_vote = num_in_class;
+ winner = key;
+ }
+ }
+ this.setNum_supperted_target_classes(num_supporters);
+ this.setThe_winner_target_class(winner);
+
+ }
+
public void setDistribution(FactTargetDistribution targetDist) {
for (Object targetValue: targetDomain.getValues()) {
num_at_target.put(targetValue, targetDist.getVoteFor(targetValue));
}
+ num_at_target.put(attr_sum, targetDist.getSum());
+ }
+
+ public String toString() {
+ String out = "FTD: target:"+ this.targetDomain.getName()+ " total: "+ this.getSum() + " dist:";
+ for (Object value: this.getTargetClasses())
+ out += this.getVoteFor(value) +" @"+value+ ", ";
+ out +="\n";
+ return out;
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -136,5 +136,9 @@
public List<Integer> getIndices() {
return null; //indices;
}
+
+ public void addIndex(int index) {
+ // Intentional no-op: LiteralDomain does not track split indices (getIndices() returns null).
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -177,6 +177,10 @@
indices.clear();
indices.addAll(split_indices);
}
+
+ public void addIndex(int index) {
+ indices.add(index);
+ }
public List<Integer> getIndices() {
return indices;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FactProcessor.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -9,7 +9,6 @@
import dt.memory.Domain;
import dt.memory.Fact;
import dt.memory.FactDistribution;
-import dt.memory.FactTargetDistribution;
public class FactProcessor {
@@ -19,9 +18,11 @@
return FactProcessor.splitFacts_disc(facts, choosenDomain);
} else {
Collections.sort(facts, choosenDomain.factComparator()); /* hack*/
- return FactProcessor.splitFacts_cont(facts, choosenDomain);
+ return FactProcessor.splitFacts_cont_opt(facts, choosenDomain);
}
}
+
+
public static Hashtable<Object, List<Fact>> splitFacts_disc(List<Fact> facts, Domain<?> choosenDomain) {
String attributeName = choosenDomain.getName();
List<?> attributeValues = choosenDomain.getValues();
@@ -34,50 +35,89 @@
}
return factLists;
}
-
+ public static Hashtable<Object, List<Fact>> splitFacts_cont(List<Fact> facts, Domain<?> attributeDomain) {
+
+ List<?> splitValues = attributeDomain.getValues();
+ List<Integer> splitIndices = attributeDomain.getIndices();
+
+ System.out.println("Numof classes in domain "+ attributeDomain.getValues().size());
+ System.out.println("Numof splits in domain "+ attributeDomain.getIndices().size());
+
+ System.out.println("Numof splits in indices "+ splitValues.size());
+ Hashtable<Object, List<Fact>> factLists = new Hashtable<Object, List<Fact>>(splitValues.size());
+ for (Object v: attributeDomain.getValues()) {
+ factLists.put(v, new ArrayList<Fact>());
+ }
+
+
+ int index = 0;
+ int split_index = 0;
+ Object attr_key = splitValues.get(split_index);
+ List<Fact> subList = factLists.get(attr_key);
+ for (Fact f : facts) {
+
+ if (index == splitIndices.get(split_index).intValue()+1 ) {
+ attr_key = splitValues.get(split_index+1);
+ subList = factLists.get(attr_key);
+ split_index++;
+ }
+ subList.add(f);
+ index++;
+ }
+
+ return factLists;
+
+ }
/* it must work */
- private static Hashtable<Object, List<Fact>> splitFacts_cont(
+ private static Hashtable<Object, List<Fact>> splitFacts_cont_opt(
List<Fact> facts, Domain<?> attributeDomain) {
String attributeName = attributeDomain.getName();
if (Util.DEBUG) System.out.println("FactProcessor.splitFacts_cont() attr_split "+ attributeName);
- List<?> categorization = attributeDomain.getValues();
- List<Integer> split_indices = attributeDomain.getIndices();
+ List<?> splitValues = attributeDomain.getValues();
+ List<Integer> splitIndices = attributeDomain.getIndices();
if (Util.DEBUG) {
- System.out.println("FactProcessor.splitFacts_cont() haniymis benim repsentativelerim: "+ categorization.size() + " and the split points "+ split_indices.size());
+ System.out.println("FactProcessor.splitFacts_cont() haniymis benim repsentativelerim: "+ splitValues.size() + " and the split points "+ splitIndices.size());
System.out.println("FactProcessor.splitFacts_cont() before splitting "+ facts.size());
- int split_i =0;
- for(int i=0; i<facts.size(); i++) {
- if (split_i<split_indices.size() && split_indices.get(split_i).intValue()== i) {
- System.out.println("PRINT*: FactProcessor.splitFacts_cont() will split at "+i + " the fact "+facts.get(i));
- split_i ++;
+
+ int index = 0;
+ int split_index = 0;
+ Object attr_key = splitValues.get(split_index);
+ for (Fact f : facts) {
+
+ if (index == splitIndices.get(split_index).intValue()+1 ) {
+ System.out.print("PRINT* (");
+ attr_key = splitValues.get(split_index+1);
+ split_index++;
} else {
- System.out.println("PRINT: FactProcessor.splitFacts_cont() at "+i + " the fact "+facts.get(i));
+ System.out.print("PRINT (");
}
+ System.out.println(split_index+"): fact "+f);
+ index++;
}
+
}
- Hashtable<Object, List<Fact>> factLists = new Hashtable<Object, List<Fact>>(categorization.size());
+ Hashtable<Object, List<Fact>> factLists = new Hashtable<Object, List<Fact>>(splitValues.size());
for (Object v: attributeDomain.getValues()) {
factLists.put(v, new ArrayList<Fact>());
}
//Comparator<Fact> cont_comp = attributeDomain.factComparator();
- Iterator<Integer> splits_it = split_indices.iterator();
+ Iterator<Integer> splits_it = splitIndices.iterator();
int start_point = 0;
int index = 0;
-
- while (splits_it.hasNext() || index < attributeDomain.getValues().size()) {
- int integer_index;
- if (splits_it.hasNext())
- integer_index = splits_it.next().intValue();
- else
- integer_index = facts.size();
+ while (splits_it.hasNext()) {// || index < attributeDomain.getValues().size()
+ int integer_index = splits_it.next().intValue();
+// if (splits_it.hasNext())
+// integer_index = splits_it.next().intValue();
+// else
+// integer_index = facts.size();
- Object category = attributeDomain.getValues().get(index);
+ Object category = splitValues.get(index);
//System.out.println("FactProcessor.splitFacts_cont() new category: "+ category);
Fact pseudo = new Fact();
try {
@@ -86,8 +126,8 @@
System.out.println("FactProcessor.splitFacts_cont() new category: "+ category );
System.out.println(" ("+start_point+","+integer_index+")");
}
- factLists.put(category, facts.subList(start_point, integer_index));
- start_point = integer_index;
+ factLists.put(category, facts.subList(start_point, integer_index+1));
+ start_point = integer_index+1;
} catch (Exception e) {
// TODO Auto-generated catch block
@@ -105,19 +145,20 @@
List<Fact> unclassified_facts, FactDistribution stats) {
Object winner = stats.getThe_winner_target_class();
- System.out.println(Util.ntimes("DANIEL", 2)+ " lets get unclassified daniel winner "+winner +" num of sup " +stats.getVoteFor(winner));
+ //System.out.println(Util.ntimes("DANIEL", 2)+ " lets get unclassified daniel winner "+winner +" num of sup " +stats.getVoteFor(winner));
for (Object looser: stats.getTargetClasses()) {
int num_supp = stats.getVoteFor(looser);
if ((num_supp > 0) && !winner.equals(looser)) {
- System.out.println(Util.ntimes("DANIEL", 2)+ " one looser ? "+looser + " num of sup="+num_supp);
+ //System.out.println(Util.ntimes("DANIEL", 2)+ " one looser ? "+looser + " num of sup="+num_supp);
//System.out.println(" the num of supporters = "+ stats.getVoteFor(looser));
//System.out.println(" but the guys "+ stats.getSupportersFor(looser));
//System.out.println("How many bok: "+stats.getSupportersFor(looser).size());
unclassified_facts.addAll(stats.getSupportersFor(looser));
- } else
- System.out.println(Util.ntimes("DANIEL", 5)+ "how many times matching?? not a looser "+ looser );
+ } else {
+ //System.out.println(Util.ntimes("DANIEL", 5)+ "how many times matching?? not a looser "+ looser );
+ }
}
@SuppressWarnings("unused")
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -66,4 +66,34 @@
}
+
+ public static List<Object> processFileExmC45(WorkingMemory simple, Object emptyObject, String drlfile, String datafile, String separator, int max_rules) {
+
+ try {
+ List<Object> obj_read=FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator);
+ C45TreeBuilder bocuk = new C45TreeBuilder();
+
+ long dt = System.currentTimeMillis();
+ String target_attr = ObjectReader.getTargetAnnotation(emptyObject.getClass());
+
+ List<String> workingAttributes= ObjectReader.getWorkingAttributes(emptyObject.getClass());
+
+ DecisionTree bocuksTree = bocuk.build(simple, emptyObject.getClass().getName(), target_attr, workingAttributes);
+ dt = System.currentTimeMillis() - dt;
+ System.out.println("Time" + dt + "\n" + bocuksTree);
+
+ RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed(), max_rules);
+ boolean sort_via_rank = true;
+ my_printer.printer(bocuksTree, "examples", "src/rules/examples/"+drlfile, sort_via_rank);
+
+ return obj_read;
+
+ } catch (Exception e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ return null;
+
+
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -26,6 +26,7 @@
private boolean ONLY_ACTIVE = true;
private int num_facts;
//private RuleComparator rule_comp = new RuleComparator();
+ private int max_num_rules;
public RulePrinter(int num_facts) {
@@ -39,6 +40,18 @@
this.num_facts = num_facts;
}
+
+ public RulePrinter(int num_facts, int max_num_rules) {
+ ruleText = new ArrayList<String>();
+ //rule_list = new ArrayList<ArrayList<NodeValue>>();
+ rules = new ArrayList<Rule>();
+
+ /* most important */
+ nodes = new Stack<NodeValue>();
+
+ this.num_facts = num_facts;
+ this.max_num_rules = max_num_rules;
+ }
public int getNum_facts() {
return num_facts;
}
@@ -90,7 +103,9 @@
write("\n", true, outputFile);
}
}
- total_num_facts += rule.getPopularity();
+ total_num_facts += rule.getPopularity();
+ if (i == getMax_num_rules())
+ break;
}
if (outputFile!=null) {
write("//THE END: Total number of facts correctly classified= "+ total_num_facts + " over "+ getNum_facts() , true, outputFile);
@@ -103,7 +118,7 @@
}
private void dfs(TreeNode my_node) {
- System.out.println("How many guys there of "+my_node.getDomain().getName() +" : "+my_node.getDomain().getValues().size());
+ //System.out.println("How many guys there of "+my_node.getDomain().getName() +" : "+my_node.getDomain().getValues().size());
NodeValue node_value = new NodeValue(my_node);
nodes.push(node_value);
@@ -217,6 +232,16 @@
}
}
}
+
+
+ public int getMax_num_rules() {
+ return max_num_rules;
+ }
+
+
+ public void setMax_num_rules(int max_num_rules) {
+ this.max_num_rules = max_num_rules;
+ }
}
class Rule {
@@ -369,7 +394,7 @@
return node.getDomain() + " == "+ value;
else {
int size = node.getDomain().getValues().size();
- System.out.println("How many guys there of "+node.getDomain().getName() +" and the value "+nodeValue+" : "+size);
+ //System.out.println("How many guys there of "+node.getDomain().getName() +" and the value "+nodeValue+" : "+size);
if (node.getDomain().getValues().lastIndexOf(nodeValue) == size-1)
return node.getDomain() + " > "+ node.getDomain().getValues().get(size-2);
else
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -7,7 +7,7 @@
public class Util {
- public static boolean DEBUG = false;
+ public static boolean DEBUG = true;
public static String ntimes(String s,int n){
StringBuffer buf = new StringBuffer();
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java 2008-04-04 00:21:45 UTC (rev 19408)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java 2008-04-04 01:21:08 UTC (rev 19409)
@@ -1,10 +1,14 @@
package test;
+import java.util.List;
+
import dt.DecisionTree;
+import dt.builder.C45TreeBuilder;
import dt.builder.IDTreeBuilder;
import dt.memory.FactSetFactory;
import dt.memory.WorkingMemory;
+import dt.tools.ObjectReader;
import dt.tools.RulePrinter;
public class BocukFileExample {
@@ -44,7 +48,7 @@
my_printer.printer(bocuksTree, null, null, sort_via_rank);
}
}
+
-
}
More information about the jboss-svn-commits
mailing list