[jboss-svn-commits] JBL Code SVN: r19329 - in labs/jbossrules/contrib/machinelearning/decisiontree/src: dt/builder and 3 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Mon Mar 31 00:09:10 EDT 2008
Author: gizil
Date: 2008-03-31 00:09:10 -0400 (Mon, 31 Mar 2008)
New Revision: 19329
Added:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java
Modified:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/ObjectReader.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/test/RestaurantOld.java
Log:
hackish working discretization
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -1,13 +1,8 @@
package dt;
import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Comparator;
import java.util.Hashtable;
-import java.util.Iterator;
import java.util.List;
-import java.util.Set;
import dt.memory.Domain;
import dt.memory.Fact;
@@ -37,10 +32,9 @@
this.attrsToClassify = new ArrayList<String>();
}
- private Object getConsensus(List<Fact> facts) {
+ public Object getMajority(List<Fact> facts) {
List<?> targetValues = getPossibleValues(this.target);
Hashtable<Object, Integer> facts_in_class = getStatistics(facts, target);
- // , targetValues
int winner_vote = 0;
Object winner = null;
@@ -55,397 +49,6 @@
return winner;
}
- // *OPT* public double calculateGain(List<FactSet> facts, String
- // attributeName) {
- // I dont use
- public double calculateGain(List<Fact> facts,
- Hashtable<Object, Integer> facts_in_class, String attributeName) {
-
- return getInformation(facts_in_class, facts.size())
- - getGain(facts, attributeName);
- }
-
- // *OPT* public double getGain(List<FactSet> facts, String attributeToSplit)
- // {
- public double getGain(List<Fact> facts, String attributeToSplit) {
- System.out.println("What is the attributeToSplit? " + attributeToSplit);
- List<?> attributeValues = getPossibleValues(attributeToSplit);
-
- String attr_sum = "sum";
-
- List<?> targetValues = getPossibleValues(getTarget());
- // Hashtable<Object, Integer> facts_in_class = new Hashtable<Object,
- // Integer>(targetValues.size());
-
- /* initialize the hashtable */
- Hashtable<Object, Hashtable<Object, Integer>> facts_of_attribute = new Hashtable<Object, Hashtable<Object, Integer>>(
- attributeValues.size());
- for (Object attr : attributeValues) {
- facts_of_attribute.put(attr, new Hashtable<Object, Integer>(
- targetValues.size() + 1));
- for (Object t : targetValues) {
- facts_of_attribute.get(attr).put(t, 0);
- }
- facts_of_attribute.get(attr).put(attr_sum, 0);
- }
-
- int total_num_facts = 0;
- // *OPT* for (FactSet fs: facts) {
- // *OPT* for (Fact f: fs.getFacts()) {
- for (Fact f : facts) {
- total_num_facts++;
- Object targetKey = f.getFieldValue(target);
- // System.out.println("My key: "+ targetKey.toString());
-
- Object attr_key = f.getFieldValue(attributeToSplit);
- int num = facts_of_attribute.get(attr_key).get(targetKey)
- .intValue();
- num++;
- facts_of_attribute.get(attr_key).put(targetKey, num);
-
- int total_num = facts_of_attribute.get(attr_key).get(attr_sum)
- .intValue();
- total_num++;
- facts_of_attribute.get(attr_key).put(attr_sum, total_num);
-
- // System.out.println("getGain of "+attributeToSplit+
- // ": total_num "+ facts_of_attribute.get(attr_key).get(attr_sum) +
- // " and "+facts_of_attribute.get(attr_key).get(targetKey) +
- // " at attr=" + attr_key + " of t:"+targetKey);
- }
- FACTS_READ += facts.size();
- // *OPT* }
- // *OPT* }
- double sum = getAttrInformation(facts_of_attribute, total_num_facts);
-// for (Object attr : attributeValues) {
-// int total_num_attr = facts_of_attribute.get(attr).get(attr_sum)
-// .intValue();
-//
-// double sum_attr = 0.0;
-// if (total_num_attr > 0)
-// for (Object t : targetValues) {
-// int num_attr_target = facts_of_attribute.get(attr).get(t)
-// .intValue();
-//
-// double prob = (double) num_attr_target / total_num_attr;
-// // System.out.println("prob "+ prob);
-// sum_attr += (prob == 0.0) ? 0.0 : (-1 * prob * Util
-// .log2(prob));
-// }
-// sum += ((double) total_num_attr / (double) total_num_facts)
-// * sum_attr;
-// }
- return sum;
- }
-
- /*
- * GLOBAL DISCRETIZATION a a b a b b b b b (target) 1 2 3 4 5 6 7 8 9 (attr
- * c) 0 0 0 0 1 1 1 1 1 "<5", ">=5" "true" "false"
- */
- /*
- * The algorithm is basically (per attribute):
- *
- * 1. Sort the instances on the attribute of interest
- *
- * 2. Look for potential cut-points. Cut points are points in the sorted
- * list above where the class labels change. Eg. if I had five instances
- * with values for the attribute of interest and labels (1.0,A), (1.4,A),
- * (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only two cutpoints
- * of interest: 1.85 and 5 (mid-way between the points where the classes
- * change from A to B or vice versa).
- *
- * 3. Evaluate your favourite disparity measure (info gain, gain ratio, gini
- * coefficient, chi-squared test) on each of the cutpoints, and choose the
- * one with the maximum value (I think Fayyad and Irani used info gain).
- *
- * 4. Repeat recursively in both subsets (the ones less than and greater
- * than the cutpoint) until either (a) the subset is pure i.e. only contains
- * instances of a single class or (b) some stopping criterion is reached. I
- * can't remember what stopping criteria they used.
- */
-
- // *OPT* public double getGain(List<FactSet> facts, String attributeToSplit)
- public double getContinuousGain(List<Fact> facts,
- List<Integer> split_facts, int begin_index, int end_index,
- Hashtable<Object, Integer> facts_in_class, String attributeToSplit) {
-
- System.out.println("What is the attributeToSplit? " + attributeToSplit);
-
- if (facts.size() <= 1) {
- System.out
- .println("The size of the fact list is 0 oups??? exiting....");
- System.exit(0);
- }
- if (split_facts.size() < 1) {
- System.out
- .println("The size of the splits is 0 oups??? exiting....");
- System.exit(0);
- }
-
- String targetAttr = getTarget();
- List<?> targetValues = getPossibleValues(getTarget());
- List<?> boundaries = getPossibleValues(attributeToSplit);
-
- // Fact split_point = facts.get(facts.size() / 2);
- // a b a a b
- // 1 2 3 4 5
- // 1.5
- // 2.5
- // 3.5
- // 0.00001 0.00002 1 100
- // 0.000015
-
- // < 50 >
- // 25 75
- // HashTable<Boolean>
-
- String attr_sum = Util.getSum();
-
-
-
- /* initialize the hashtable */
- Hashtable<Object, Hashtable<Object, Integer>> facts_of_attribute =
- new Hashtable<Object, Hashtable<Object, Integer>>(Util.getDividingSize());
- // attr_0 bhas nothing everything inside attr_1
- Object cut_point; //attr_0
- Object last_poit = facts.get(facts.size()-1).getFieldValue(attributeToSplit);
- for (int i = 0; i < 2; i++) {
- facts_of_attribute.put(Integer.valueOf(i),
- new Hashtable<Object, Integer>(targetValues.size() + 1));
- //Hashtable<Object, Integer> facts_in_class
- if (i == 1) {
- for (Object t : targetValues) {
- facts_of_attribute.get(Integer.valueOf(i)).put(t,
- facts_in_class.get(t));
- }
- facts_of_attribute.get(Integer.valueOf(i)).put(attr_sum,
- facts.size());
- } else {
- for (Object t : targetValues) {
- facts_of_attribute.get(Integer.valueOf(i)).put(t, 0);
- }
- facts_of_attribute.get(Integer.valueOf(i)).put(attr_sum, 0);
- }
- }
-
- /*
- * 2. Look for potential cut-points.
- */
-
- int split_index = 1;
- int last_index = facts.size();
- Iterator<Fact> f_ite = facts.iterator();
- Fact f1 = f_ite.next();
- while (f_ite.hasNext()) {
-
- Fact f2 = f_ite.next();
-
- // everytime it is not a split change the place in the distribution
-
- Object targetKey = f2.getFieldValue(target);
-
- // System.out.println("My key: "+ targetKey.toString());
-
- //for (Object attr_key : attr_values)
-
- Object attr_key_1 = Integer.valueOf(0);
- int num_1 = facts_of_attribute.get(attr_key_1).get(targetKey).intValue();
- num_1++;
- facts_of_attribute.get(attr_key_1).put(targetKey, num_1);
-
- int total_num_1 = facts_of_attribute.get(attr_key_1).get(attr_sum).intValue();
- total_num_1++;
- facts_of_attribute.get(attr_key_1).put(attr_sum, total_num_1);
-
- Object attr_key_2= Integer.valueOf(1);
- int num_2 = facts_of_attribute.get(attr_key_2).get(targetKey).intValue();
- num_2--;
- facts_of_attribute.get(attr_key_2).put(targetKey, num_2);
-
- int total_num_2 = facts_of_attribute.get(attr_key_2).get(attr_sum).intValue();
- total_num_2++;
- facts_of_attribute.get(attr_key_2).put(attr_sum, total_num_2);
-
- /*
- * 2.1 Cut points are points in the sorted list above where the class labels change.
- * Eg. if I had five instances with values for the attribute of interest and labels
- * (1.0,A), (1.4,A), (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only
- * two cutpoints of interest: 1.85 and 5 (mid-way between the points
- * where the classes change from A to B or vice versa).
- */
- if (f1.getFieldValue(targetAttr) != f2.getFieldValue(targetAttr)) {
- // the cut point
- Number cp_i = (Number) f1.getFieldValue(attributeToSplit);
- Number cp_i_next = (Number) f2.getFieldValue(attributeToSplit);
-
- cut_point = (cp_i.doubleValue() + cp_i_next
- .doubleValue()) / 2;
-
- /*
- * 3. Evaluate your favourite disparity measure
- * (info gain, gain ratio, gini coefficient, chi-squared test) on the cut point
- * and calculate its gain
- */
- double sum = getAttrInformation(facts_of_attribute, facts.size());
-//
-// double sum = 0.0;
-// // for (Object attr : attributeValues) {
-// for (int i = 0; i < 2; i++) {
-//
-// int total_num_attr = facts_of_attribute.get(Integer.valueOf(i)).get(attr_sum).intValue();
-//
-// double sum_attr = 0.0;
-// if (total_num_attr > 0)
-// for (Object t : targetValues) {
-// int num_attr_target = facts_of_attribute.get(Integer.valueOf(i)).get(t).intValue();
-//
-// double prob = (double) num_attr_target / total_num_attr;
-// // System.out.println("prob "+ prob);
-// sum_attr += (prob == 0.0) ? 0.0 : (-1 * prob * Util.log2(prob));
-// }
-// sum += ((double) total_num_attr / (double) facts.size())* sum_attr;
-// }
-
-
- } else {}
-
-// getContinuousGain(facts, split_facts.subList(0,
-// split_index+1), 0, split_index+1,
-// facts_in_class1, attributeToSplit);
-//
-// getContinuousGain(facts, split_facts.subList(split_index+1,
-// last_index), split_index+1, last_index,
-// facts_in_class2, attributeToSplit);
-
- f1 = f2;
- split_index ++;
- }
-
- return 1.0;
- }
-
- public double getContinuousGain_(List<Fact> facts,
- List<Integer> split_facts, int begin_index, int end_index,
- Hashtable<Object, Integer> facts_in_class, String attributeToSplit) {
- System.out.println("What is the attributeToSplit? " + attributeToSplit);
-
- if (facts.size() <= 1) {
- System.out
- .println("The size of the fact list is 0 oups??? exiting....");
- System.exit(0);
- }
- if (split_facts.size() < 1) {
- System.out
- .println("The size of the splits is 0 oups??? exiting....");
- System.exit(0);
- }
-
- String targetAttr = getTarget();
- List<?> boundaries = getPossibleValues(attributeToSplit);
-
- // Fact split_point = facts.get(facts.size() / 2);
- // a b a a b
- // 1 2 3 4 5
- // 1.5
- // 2.5
- // 3.5
- // 0.00001 0.00002 1 100
- // 0.000015
-
- // < 50 >
- // 25 75
- // HashTable<Boolean>
-
- String attr_sum = "sum";
-
- /*
- * 2. Look for potential cut-points. Cut points are points in the sorted
- * list above where the class labels change. Eg. if I had five instances
- * with values for the attribute of interest and labels (1.0,A),
- * (1.4,A), (1.7, A), (2.0,B), (3.0, B), (7.0, A), then there are only
- * two cutpoints of interest: 1.85 and 5 (mid-way between the points
- * where the classes change from A to B or vice versa).
- */
-
- /* initialize the hashtable */
- // Hashtable<Object, Hashtable<Object, Integer>> facts_of_attribute =
- // new Hashtable<Object, Hashtable<Object,
- // Integer>>(Util.getDividingSize());
- // for (Object attr : attributeValues) {
- // facts_of_attribute.put(attr, new Hashtable<Object, Integer>(
- // targetValues.size() + 1));
- // for (Object t : targetValues) {
- // facts_of_attribute.get(attr).put(t, 0);
- // }
- // facts_of_attribute.get(attr).put(attr_sum, 0);
- // }
- //
- int split_index = 0;
- Iterator<Integer> split_ite = split_facts.iterator();
- int f1_index = split_ite.next().intValue();
- Fact f1 = facts.get(f1_index);
- while (split_ite.hasNext()) {
- int f2_index = f1_index + 1;
- Fact f2 = facts.get(f2_index);
-
- if (f1.getFieldValue(targetAttr) == f2.getFieldValue(targetAttr)) {
- // the cut point
- System.out
- .println("Bok i have splited what the fuck is happening f1:"
- + f1 + " f2:" + f2);
- System.exit(0);
-
- }
- Number cp_i = (Number) f1.getFieldValue(attributeToSplit);
- Number cp_i_next = (Number) f2.getFieldValue(attributeToSplit);
-
- Object cut_point = (cp_i.doubleValue() + cp_i_next.doubleValue()) / 2;
- // calculate the gain of the cut point
-
- /*
- * 3. Evaluate your favourite disparity measure (info gain, gain
- * ratio, gini coefficient, chi-squared test) on each of the
- * cutpoints, and choose the one with the maximum value (I think
- * Fayyad and Irani used info gain).
- */
- // double sum = 0.0;
- // //for (Object attr : attributeValues) {
- // for (int i = 1; i<2; i++) {
- //
- // int total_num_attr =
- // facts_of_attribute.get(attr).get(attr_sum).intValue();
- //
- // double sum_attr = 0.0;
- // if (total_num_attr > 0)
- // for (Object t : targetValues) {
- // int num_attr_target =
- // facts_of_attribute.get(attr).get(t).intValue();
- //
- // double prob = (double) num_attr_target/ total_num_attr;
- // // System.out.println("prob "+ prob);
- // sum_attr += (prob == 0.0) ? 0.0 : (-1 * prob * Util.log2(prob));
- // }
- // sum += ((double) total_num_attr / (double) total_num_facts)*
- // sum_attr;
- // }
- // getContinuousGain(facts, split_facts.subList(fromIndex,
- // centerIndex), begin_index, middle_index,
- // facts_in_class1, attributeToSplit);
- //
- // getContinuousGain(facts, split_facts.subList(centerIndex,
- // toIndex), middle_index+1, end_index,
- // facts_in_class2, attributeToSplit);
- f1_index = split_ite.next().intValue();
- f1 = facts.get(f1_index);
- }
-
- List<?> targetValues = getPossibleValues(target);
- // Hashtable<Object, Integer> facts_in_class = new Hashtable<Object,
- // Integer>(targetValues.size());
-
- return 1.0;
- }
-
// *OPT* public double getInformation(List<FactSet> facts) {
public Hashtable<Object, Integer> getStatistics(List<Fact> facts, String target) {
@@ -505,45 +108,8 @@
}
return sum;
}
-
- public double getAttrInformation( Hashtable<Object, Hashtable<Object, Integer>> facts_of_attribute, int fact_size) {
-
- Collection<Object> attributeValues = facts_of_attribute.keySet();
- String attr_sum = Util.getSum();
- double sum = 0.0;
- for (Object attr : attributeValues) {
- int total_num_attr = facts_of_attribute.get(attr).get(attr_sum).intValue();
- //double sum_attr = 0.0;
- if (total_num_attr > 0) {
- sum += ((double) total_num_attr / (double) fact_size)*
- getInformation(facts_of_attribute.get(attr), total_num_attr);
- }
- }
- return sum;
- }
- public double getInformation(Hashtable<Object, Integer> facts_in_class, int total_num_facts) {
- // List<?> targetValues = getPossibleValues(this.target);
- // Hashtable<Object, Integer> facts_in_class = getStatistics(facts,
- // getTarget()); //, targetValues);
- Collection<Object> targetValues = facts_in_class.keySet();
- double sum = 0;
- for (Object key : targetValues) {
- int num_in_class = facts_in_class.get(key).intValue();
- // System.out.println("num_in_class : "+ num_in_class + " key "+ key
- // + " and the total num "+ total_num_facts);
- double prob = (double) num_in_class / (double) total_num_facts;
-
- // double log2= Util.log2(prob);
- // double plog2p= prob*log2;
- sum += (prob == 0.0) ? 0.0 : -1 * prob * Util.log2(prob);
- // System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"
- // where the sum: "+sum);
- }
- return sum;
- }
-
public void setTarget(String targetField) {
target = targetField;
attrsToClassify.remove(target);
@@ -560,6 +126,7 @@
return domainSet.get(fieldName).getValues();
}
+
public List<String> getAttributes() {
return attrsToClassify;
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -3,7 +3,6 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.Comparator;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
@@ -17,6 +16,7 @@
import dt.memory.FactSet;
import dt.memory.OOFactSet;
import dt.memory.Domain;
+import dt.tools.FactProcessor;
import dt.tools.Util;
public class C45TreeBuilder implements DecisionTreeBuilder {
@@ -166,8 +166,7 @@
// LeafNode(facts.get(0).getFact(0).getFieldValue(target));
LeafNode classifiedNode = new LeafNode(
dt.getDomain(dt.getTarget()), winner);
- classifiedNode.setRank((double) facts.size()
- / (double) num_fact_processed);
+ classifiedNode.setRank((double) facts.size()/(double) num_fact_processed);
return classifiedNode;
}
@@ -181,33 +180,29 @@
return noAttributeLeftNode;
}
- /* id3 starts */
- String chosenAttribute = attributeWithGreatestGain(dt, facts, stats,
- attributeNames);
+ /* choosing the attribute for the branching starts */
+// String chosenAttribute = Entropy.chooseContAttribute(dt, facts, stats, attributeNames);
+// List<?> categorization = dt.getPossibleValues(chosenAttribute);
+ Domain<?> choosenDomain = Entropy.chooseContAttribute(dt, facts, stats, attributeNames);
+ System.out.println(Util.ntimes("*", 20) + " 1st best attr: "+ choosenDomain.getName());
- System.out.println(Util.ntimes("*", 20) + " 1st best attr: "
- + chosenAttribute);
+ TreeNode currentNode = new TreeNode(choosenDomain);
- TreeNode currentNode = new TreeNode(dt.getDomain(chosenAttribute));
- // ConstantDecisionTree m = majorityValue(ds);
- /* the majority */
+ Hashtable<Object, List<Fact>> filtered_facts = null;
- List<?> attributeValues = dt.getPossibleValues(chosenAttribute);
- Hashtable<Object, List<Fact>> filtered_facts = splitFacts(facts,
- chosenAttribute, attributeValues);
+ if (choosenDomain.isDiscrete()) {
+ filtered_facts = FactProcessor.splitFacts_disc(facts, choosenDomain.getName(), choosenDomain.getValues());
+ } else {
+ filtered_facts = FactProcessor.splitFacts_cont(facts, choosenDomain);
+ }
dt.FACTS_READ += facts.size();
- // if (FUNC_CALL ==5) {
- // System.out.println("FUNC_CALL:" +FUNC_CALL);
- // System.exit(0);
- // }
- for (int i = 0; i < attributeValues.size(); i++) {
+ for (Object value : filtered_facts.keySet()) {
/* split the last two class at the same time */
- Object value = attributeValues.get(i);
ArrayList<String> attributeNames_copy = new ArrayList<String>(
attributeNames);
- attributeNames_copy.remove(chosenAttribute);
+ attributeNames_copy.remove(choosenDomain.getName());
if (filtered_facts.get(value).isEmpty()) {
/* majority !!!! */
@@ -224,137 +219,9 @@
return currentNode;
}
-
- // String chooseAttribute(List<FactSet> facts, List<String> attrs) {
- public String attributeWithGreatestGain(DecisionTree dt, List<Fact> facts,
- Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
-
- double dt_info = dt.getInformation(facts_in_class, facts.size());
- double greatestGain = 0.0;
- String attributeWithGreatestGain = attrs.get(0);
- for (String attr : attrs) {
- double gain = 0;
- if (dt.getDomain(attr).isDiscrete()) {
- gain = dt_info - dt.getGain(facts, attr);
- } else {
- /* 1. sort the values */
- int begin_index = 0;
- int end_index = facts.size();
- Collections.sort(facts,
- new FactNumericAttributeComparator(attr));
- List<Integer> splits = getSplitPoints(facts, dt.getTarget());
- gain = dt_info
- - dt.getContinuousGain(facts, splits, begin_index,
- end_index, facts_in_class, attr);
- // gain = dt_info - dt.getContinuousGain(facts, facts_in_class,
- // attr);
- }
-
- System.out.println("Attribute: " + attr + " the gain: " + gain);
- if (gain > greatestGain) {
- greatestGain = gain;
- attributeWithGreatestGain = attr;
- }
- }
-
- return attributeWithGreatestGain;
- }
-
- /*
- * id3 uses that function because it can not classify continuous attributes
- */
-
- public String attributeWithGreatestGain_discrete(DecisionTree dt,
- List<Fact> facts, Hashtable<Object, Integer> facts_in_class,
- List<String> attrs) {
-
- double dt_info = dt.getInformation(facts_in_class, facts.size());
- double greatestGain = 0.0;
- String attributeWithGreatestGain = attrs.get(0);
- for (String attr : attrs) {
- double gain = 0;
- if (!dt.getDomain(attr).isDiscrete()) {
- System.err.println("Ignoring the attribute:" + attr
- + " the id3 can not classify continuous attributes");
- continue;
- } else {
- gain = dt_info - dt.getGain(facts, attr);
- }
- System.out.println("Attribute: " + attr + " the gain: " + gain);
- if (gain > greatestGain) {
- greatestGain = gain;
- attributeWithGreatestGain = attr;
- }
-
- }
-
- return attributeWithGreatestGain;
- }
-
- private List<Integer> getSplitPoints(List<Fact> facts, String target) {
- List<Integer> splits = new ArrayList<Integer>();
- Iterator<Fact> it_f = facts.iterator();
- Fact f1 = it_f.next();
- int index = 0;
- while (it_f.hasNext()) {
- Fact f2 = it_f.next();
- if (f1.getFieldValue(target) != f2.getFieldValue(target))
- splits.add(Integer.valueOf(index));
-
- f1 = f2;
- index++;
- }
- return splits;
- }
-
- public Hashtable<Object, List<Fact>> splitFacts(List<Fact> facts,
- String attributeName, List<?> attributeValues) {
- Hashtable<Object, List<Fact>> factLists = new Hashtable<Object, List<Fact>>(
- attributeValues.size());
- for (Object v : attributeValues) {
- factLists.put(v, new ArrayList<Fact>());
- }
- for (Fact f : facts) {
- factLists.get(f.getFieldValue(attributeName)).add(f);
- }
- return factLists;
- }
-
- public void testEntropy(DecisionTree dt, List<Fact> facts) {
- Hashtable<Object, Integer> facts_in_class = dt.getStatistics(facts, dt
- .getTarget());// , targetValues
- double initial_info = dt.getInformation(facts_in_class, facts.size()); // entropy
- // value
-
- System.out.println("initial_information: " + initial_info);
-
- String first_attr = attributeWithGreatestGain(dt, facts,
- facts_in_class, dt.getAttributes());
-
- System.out.println("best attr: " + first_attr);
- }
-
public int getNumCall() {
return FUNC_CALL;
}
- private class FactNumericAttributeComparator implements Comparator<Fact> {
- private String attr_name;
- public FactNumericAttributeComparator(String _attr_name) {
- attr_name = _attr_name;
- }
-
- public int compare(Fact f0, Fact f1) {
- Number n0 = (Number) f0.getFieldValue(attr_name);
- Number n1 = (Number) f1.getFieldValue(attr_name);
- if (n0.doubleValue() < n1.doubleValue())
- return -1;
- else if (n0.doubleValue() > n1.doubleValue())
- return 1;
- else
- return 0;
- }
- }
-
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilderMT.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -227,7 +227,7 @@
}
/* id3 starts */
- String chosenAttribute = attributeWithGreatestGain(dt, facts, stats, attributeNames);
+ String chosenAttribute = Entropy.chooseAttribute(dt, facts, stats, attributeNames);
System.out.println(Util.ntimes("*", 20)+" 1st best attr: "+ chosenAttribute);
@@ -279,24 +279,24 @@
return currentNode;
}
- //String chooseAttribute(List<FactSet> facts, List<String> attrs) {
- public String attributeWithGreatestGain(DecisionTree dt, List<Fact> facts, Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
+// //String chooseAttribute(List<FactSet> facts, List<String> attrs) {
+// public String attributeWithGreatestGain(DecisionTree dt, List<Fact> facts, Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
+//
+// double dt_info = dt.getInformation(facts_in_class, facts.size());
+// double greatestGain = 0.0;
+// String attributeWithGreatestGain = attrs.get(0);
+// for (String attr : attrs) {
+// double gain = dt_info - dt.getGain(facts, attr);
+// System.out.println("Attribute: "+attr +" the gain: "+gain);
+// if (gain > greatestGain) {
+// greatestGain = gain;
+// attributeWithGreatestGain = attr;
+// }
+// }
+//
+// return attributeWithGreatestGain;
+// }
- double dt_info = dt.getInformation(facts_in_class, facts.size());
- double greatestGain = 0.0;
- String attributeWithGreatestGain = attrs.get(0);
- for (String attr : attrs) {
- double gain = dt_info - dt.getGain(facts, attr);
- System.out.println("Attribute: "+attr +" the gain: "+gain);
- if (gain > greatestGain) {
- greatestGain = gain;
- attributeWithGreatestGain = attr;
- }
- }
-
- return attributeWithGreatestGain;
- }
-
public Hashtable<Object, List<Fact> > splitFacts(List<Fact> facts, String attributeName,
List<?> attributeValues) {
Hashtable<Object, List<Fact> > factLists = new Hashtable<Object, List<Fact> >(attributeValues.size());
@@ -308,18 +308,6 @@
}
return factLists;
}
-
- public void testEntropy(DecisionTree dt, List<Fact> facts) {
- Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget());
-
- double initial_info = dt.getInformation(stats, facts.size()); //entropy value
-
- System.out.println("initial_information: "+ initial_info);
-
- String first_attr = attributeWithGreatestGain(dt, facts, stats, dt.getAttributes());
-
- System.out.println("best attr: "+ first_attr);
- }
public int getNumCall() {
return FUNC_CALL;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -4,7 +4,6 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.Comparator;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
@@ -18,6 +17,7 @@
import dt.memory.FactSet;
import dt.memory.OOFactSet;
import dt.memory.Domain;
+import dt.tools.FactProcessor;
import dt.tools.Util;
public class IDTreeBuilder implements DecisionTreeBuilder {
@@ -145,85 +145,7 @@
//List<?> targetValues = dt.getPossibleValues(dt.getTarget());
Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget());//targetValues
Collection<Object> targetValues = stats.keySet();
- int winner_vote = 0;
- int num_supporters = 0;
- Object winner = null;
- for (Object key: targetValues) {
-
- int num_in_class = stats.get(key).intValue();
- if (num_in_class>0)
- num_supporters ++;
- if (num_in_class > winner_vote) {
- winner_vote = num_in_class;
- winner = key;
- }
- }
-
- /* if all elements are classified to the same value */
- if (num_supporters == 1) {
- //*OPT* return new LeafNode(facts.get(0).getFact(0).getFieldValue(target));
- LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
- classifiedNode.setRank((double)facts.size()/(double)num_fact_processed);
- return classifiedNode;
- }
-
- /* if there is no attribute left in order to continue */
- if (attributeNames.size() == 0) {
- /* an heuristic of the leaf classification*/
- LeafNode noAttributeLeftNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
- noAttributeLeftNode.setRank((double)winner_vote/(double)num_fact_processed);
- return noAttributeLeftNode;
- }
-
- /* id3 starts */
- String chosenAttribute = attributeWithGreatestGain_discrete(dt, facts, stats, attributeNames);
-
- System.out.println(Util.ntimes("*", 20)+" 1st best attr: "+ chosenAttribute);
-
- TreeNode currentNode = new TreeNode(dt.getDomain(chosenAttribute));
- //ConstantDecisionTree m = majorityValue(ds);
- /* the majority */
-
- List<?> attributeValues = dt.getPossibleValues(chosenAttribute);
- Hashtable<Object, List<Fact> > filtered_facts = splitFacts(facts, chosenAttribute, attributeValues);
- dt.FACTS_READ += facts.size();
-
-// if (FUNC_CALL ==5) {
-// System.out.println("FUNC_CALL:" +FUNC_CALL);
-// System.exit(0);
-// }
- for (int i = 0; i < attributeValues.size(); i++) {
- /* split the last two class at the same time */
- Object value = attributeValues.get(i);
-
- ArrayList<String> attributeNames_copy = new ArrayList<String>(attributeNames);
- attributeNames_copy.remove(chosenAttribute);
-
- if (filtered_facts.get(value).isEmpty()) {
- /* majority !!!! */
- LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
- majorityNode.setRank(0.0);
- currentNode.addNode(value, majorityNode);
- } else {
- TreeNode newNode = id3(dt, filtered_facts.get(value), attributeNames_copy);
- currentNode.addNode(value, newNode);
- }
- }
-
- return currentNode;
- }
-
-private TreeNode c4_5(DecisionTree dt, List<Fact> facts, List<String> attributeNames) {
-
- FUNC_CALL ++;
- if (facts.size() == 0) {
- throw new RuntimeException("Nothing to classify, factlist is empty");
- }
- /* let's get the statistics of the results */
- //List<?> targetValues = dt.getPossibleValues(dt.getTarget());
- Hashtable<Object, Integer> stats = dt.getStatistics(facts, dt.getTarget());//targetValues
- Collection<Object> targetValues = stats.keySet();
int winner_vote = 0;
int num_supporters = 0;
Object winner = null;
@@ -255,7 +177,8 @@
}
/* id3 starts */
- String chosenAttribute = attributeWithGreatestGain(dt, facts, stats, attributeNames);
+ String chosenAttribute = Entropy.chooseAttribute(dt, facts, stats, attributeNames);
+ //attributeWithGreatestGain_discrete(dt, facts, stats, attributeNames);
System.out.println(Util.ntimes("*", 20)+" 1st best attr: "+ chosenAttribute);
@@ -264,7 +187,8 @@
/* the majority */
List<?> attributeValues = dt.getPossibleValues(chosenAttribute);
- Hashtable<Object, List<Fact> > filtered_facts = splitFacts(facts, chosenAttribute, attributeValues);
+ Hashtable<Object, List<Fact> > filtered_facts =
+ FactProcessor.splitFacts_disc(facts, chosenAttribute, attributeValues);
dt.FACTS_READ += facts.size();
@@ -292,130 +216,8 @@
return currentNode;
}
-
- //String chooseAttribute(List<FactSet> facts, List<String> attrs) {
- public String attributeWithGreatestGain(DecisionTree dt, List<Fact> facts,
- Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
- double dt_info = dt.getInformation(facts_in_class, facts.size());
- double greatestGain = 0.0;
- String attributeWithGreatestGain = attrs.get(0);
- for (String attr : attrs) {
- double gain = 0;
- if (dt.getDomain(attr).isDiscrete()) {
- gain = dt_info - dt.getGain(facts, attr);
- } else {
- /* 1. sort the values */
- int begin_index = 0;
- int end_index = facts.size();
- Collections.sort(facts, new FactNumericAttributeComparator(attr));
- List<Integer> splits = getSplitPoints(facts, dt.getTarget());
- gain = dt_info - dt.getContinuousGain(facts, splits,
- begin_index, end_index,
- facts_in_class, attr);
- //gain = dt_info - dt.getContinuousGain(facts, facts_in_class, attr);
- }
-
- System.out.println("Attribute: "+attr +" the gain: "+gain);
- if (gain > greatestGain) {
- greatestGain = gain;
- attributeWithGreatestGain = attr;
- }
- }
-
- return attributeWithGreatestGain;
- }
- /*
- * id3 uses that function because it can not classify continuous attributes
- */
-
- public String attributeWithGreatestGain_discrete(DecisionTree dt, List<Fact> facts,
- Hashtable<Object, Integer> facts_in_class, List<String> attrs) {
-
- double dt_info = dt.getInformation(facts_in_class, facts.size());
- double greatestGain = 0.0;
- String attributeWithGreatestGain = attrs.get(0);
- for (String attr : attrs) {
- double gain = 0;
- if (!dt.getDomain(attr).isDiscrete()) {
- System.err.println("Ignoring the attribute:" +attr+ " the id3 can not classify continuous attributes");
- continue;
- } else {
- gain = dt_info - dt.getGain(facts, attr);
- }
- System.out.println("Attribute: " + attr + " the gain: " + gain);
- if (gain > greatestGain) {
- greatestGain = gain;
- attributeWithGreatestGain = attr;
- }
-
-
- }
-
- return attributeWithGreatestGain;
- }
-
- private List<Integer> getSplitPoints(List<Fact> facts, String target) {
- List<Integer> splits = new ArrayList<Integer>();
- Iterator<Fact> it_f = facts.iterator();
- Fact f1 = it_f.next();
- int index = 0;
- while(it_f.hasNext()){
- Fact f2 = it_f.next();
- if (f1.getFieldValue(target) != f2.getFieldValue(target))
- splits.add(Integer.valueOf(index));
-
- f1= f2;
- index++;
- }
- return splits;
- }
-
-
- public Hashtable<Object, List<Fact> > splitFacts(List<Fact> facts, String attributeName,
- List<?> attributeValues) {
- Hashtable<Object, List<Fact> > factLists = new Hashtable<Object, List<Fact> >(attributeValues.size());
- for (Object v: attributeValues) {
- factLists.put(v, new ArrayList<Fact>());
- }
- for (Fact f : facts) {
- factLists.get(f.getFieldValue(attributeName)).add(f);
- }
- return factLists;
- }
-
- public void testEntropy(DecisionTree dt, List<Fact> facts) {
- Hashtable<Object, Integer> facts_in_class = dt.getStatistics(facts, dt.getTarget());//, targetValues
- double initial_info = dt.getInformation(facts_in_class, facts.size()); //entropy value
-
- System.out.println("initial_information: "+ initial_info);
-
- String first_attr = attributeWithGreatestGain(dt, facts, facts_in_class, dt.getAttributes());
-
- System.out.println("best attr: "+ first_attr);
- }
-
public int getNumCall() {
return FUNC_CALL;
}
-
- private class FactNumericAttributeComparator implements Comparator<Fact> {
- private String attr_name;
-
- public FactNumericAttributeComparator(String _attr_name) {
- attr_name = _attr_name;
- }
-
- public int compare(Fact f0, Fact f1) {
- Number n0 = (Number) f0.getFieldValue(attr_name);
- Number n1 = (Number) f1.getFieldValue(attr_name);
- if (n0.doubleValue() < n1.doubleValue())
- return -1;
- else if (n0.doubleValue() > n1.doubleValue())
- return 1;
- else
- return 0;
- }
- }
-
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -1,6 +1,7 @@
package dt.memory;
import java.util.ArrayList;
+import java.util.Comparator;
import java.util.List;
public class BooleanDomain implements Domain<Boolean> {
@@ -16,7 +17,15 @@
fValues = new ArrayList<Boolean>();
fValues.add(Boolean.TRUE);
fValues.add(Boolean.FALSE);
+ readingSeq = -1;
}
+
+ public Domain<Boolean> clone() {
+ BooleanDomain dom = new BooleanDomain(fName);
+ dom.constant = constant;
+ dom.readingSeq = readingSeq;
+ return dom;
+ }
public boolean isDiscrete() {
return true;
@@ -95,4 +104,10 @@
}
+ public Comparator<Fact> factComparator() {
+ // TODO Auto-generated method stub
+ System.out.println("BooleanDomain.factComparator() can not be continuous what is going on? ");
+ return null;
+ }
+
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -1,5 +1,6 @@
package dt.memory;
+import java.util.Comparator;
import java.util.List;
public interface Domain<T> {
@@ -25,6 +26,10 @@
void setReadingSeq(int readingSeq);
int getReadingSeq();
+
+ Comparator<Fact> factComparator();
+
+ public Domain<T> clone();
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -41,6 +41,10 @@
values.put(its_domain.getName(), value);
}
+ public Domain<?> getDomain(String field_name) {
+ return fields.get(field_name);
+ }
+
public Object getFieldValue(String field_name) {
return values.get(field_name);
}
Added: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactDistribution.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -0,0 +1,69 @@
+package dt.memory;
+
+import java.util.Collection;
+import java.util.Hashtable;
+import java.util.List;
+
+import dt.tools.Util;
+
+public class FactDistribution {
+ String attr_sum = Util.sum();
+
+ Hashtable<Object, Hashtable<Object, Integer>> facts_at_attr;
+
+ private int total_num;
+
+ public FactDistribution(List<?> attributeValues, List<?> targetValues) {
+ facts_at_attr = new Hashtable<Object, Hashtable<Object, Integer>>(attributeValues.size());
+
+ for (Object attr : attributeValues) {
+ facts_at_attr.put(attr, new Hashtable<Object, Integer>(targetValues.size() + 1));
+ for (Object t : targetValues) {
+ facts_at_attr.get(attr).put(t, 0);
+ }
+ facts_at_attr.get(attr).put(attr_sum, 0);
+ }
+
+ }
+
+ public void setTotal(int size) {
+ this.total_num = size;
+ }
+ public int getTotal() {
+ return this.total_num;
+ }
+
+ public Hashtable<Object, Integer> getAttrFor(Object attr_value) {
+ return facts_at_attr.get(attr_value);
+ }
+
+ public int getSumForAttr(Object attr_value) {
+ return facts_at_attr.get(attr_value).get(attr_sum).intValue();
+ }
+
+ public void setSumForAttr(Object attr_value, int total) {
+ facts_at_attr.get(attr_value).put(attr_sum,total);
+ }
+
+ public void setTargetDistForAttr(Object attr_value, Hashtable<Object, Integer> targetDist) {
+ for (Object target: targetDist.keySet())
+ facts_at_attr.get(attr_value).put(target,targetDist.get(target));
+ }
+
+ public void change(Object attrValue, Object targetValue, int i) {
+ int num_1 = facts_at_attr.get(attrValue).get(targetValue).intValue();
+ num_1 += i;
+ facts_at_attr.get(attrValue).put(targetValue, num_1);
+
+ int total_num_1 = facts_at_attr.get(attrValue).get(attr_sum).intValue();
+ total_num_1 += i;
+ facts_at_attr.get(attrValue).put(attr_sum, total_num_1);
+
+ }
+
+ public Collection<Object> getAttributes() {
+ return facts_at_attr.keySet();
+ }
+
+
+}
Added: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -0,0 +1,22 @@
+package dt.memory;
+
+import java.util.Comparator;
+
+public class FactNumericAttributeComparator implements Comparator<Fact> {
+ private String attr_name;
+
+ public FactNumericAttributeComparator(String _attr_name) {
+ attr_name = _attr_name;
+ }
+
+ public int compare(Fact f0, Fact f1) {
+ Number n0 = (Number) f0.getFieldValue(attr_name);
+ Number n1 = (Number) f1.getFieldValue(attr_name);
+ if (n0.doubleValue() < n1.doubleValue())
+ return -1;
+ else if (n0.doubleValue() > n1.doubleValue())
+ return 1;
+ else
+ return 0;
+ }
+}
\ No newline at end of file
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -2,6 +2,7 @@
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Comparator;
import java.util.List;
public class LiteralDomain implements Domain<String> {
@@ -17,8 +18,17 @@
fName = _name.trim();
fValues = new ArrayList<String>();
discrete = true;
+ readingSeq = -1;
}
+ public Domain<String> clone() {
+ LiteralDomain dom = new LiteralDomain(fName);
+ dom.constant = constant;
+ dom.discrete = discrete;
+ dom.readingSeq = readingSeq;
+ return dom;
+ }
+
public LiteralDomain(String _name, String[] possibleValues) {
fName = _name;
fValues = Arrays.asList(possibleValues);
@@ -102,4 +112,10 @@
return out;
}
+ public Comparator<Fact> factComparator() {
+ // TODO we need groupings to be able to discretize the LiteralDomain
+ System.out.println("LiteralDomain.factComparator() is not ready ");
+ return null;
+ }
+
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -1,24 +1,42 @@
package dt.memory;
import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
import java.util.List;
public class NumericDomain implements Domain<Number> {
private String fName;
private ArrayList<Number> fValues;
+ private ArrayList<Fact> representatives;
private boolean constant;
private boolean discrete;
private int readingSeq;
+ private Comparator<Fact> fComparator;
public NumericDomain(String _name) {
fName = _name.trim();
fValues = new ArrayList<Number>();
discrete = true;
+ fComparator = new FactNumericAttributeComparator(_name);
+ readingSeq = -1;
}
+
+ public Domain<Number> clone() {
+ NumericDomain dom = new NumericDomain(fName);
+ dom.constant = constant;
+ dom.setDiscrete(discrete);
+ dom.readingSeq = readingSeq;
+ return dom;
+ }
+
public void setDiscrete(boolean d) {
this.discrete = d;
+ if (!this.discrete) {
+ representatives = new ArrayList<Fact>();
+ }
}
public boolean isDiscrete() {
@@ -36,33 +54,15 @@
if (!fValues.contains(value))
fValues.add(value);
} else {
- if (fValues.isEmpty()) {
- fValues.add(value);
- return;
- } else if (fValues.size()==1) {
- if (value.doubleValue() < fValues.get(0).doubleValue()) {
- Number first = fValues.remove(0);
- fValues.add(value);
- fValues.add(first);
- } else if (value.doubleValue() > fValues.get(0).doubleValue()) {
- fValues.add(value);
- }
- return;
- } else {
- if (value.doubleValue() > fValues.get(1).doubleValue()) {
- fValues.remove(1);
- fValues.add(1, value);
- return;
- }
- if (value.doubleValue() < fValues.get(0).doubleValue()) {
- fValues.remove(0);
- fValues.add(0, value);
- return;
- }
- }
+
}
}
+ public void addRepresentative(Fact f) {
+ if (!representatives.contains(f))
+ representatives.add(f);
+ Collections.sort(representatives, this.factComparator());
+ }
public boolean contains(Number value) {
for(Number n: fValues) {
@@ -77,7 +77,11 @@
public List<Number> getValues() {
return fValues;
}
+ public List<Fact> getRepresentatives() {
+ return representatives;
+ }
+
public int hashCode() {
return fName.hashCode();
}
@@ -158,7 +162,9 @@
String out = fName;
return out;
}
-
+ public Comparator<Fact> factComparator() {
+ return fComparator;
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -108,9 +108,10 @@
}
Domain<?> fieldDomain;
- if (!domainset.containsKey(f_name))
+ if (!domainset.containsKey(f_name)) {
fieldDomain = DomainFactory.createDomainFromClass(f.getType(), f_name);
- else
+ domainset.put(f_name, fieldDomain);
+ } else
fieldDomain = domainset.get(f_name);
//System.out.println("WorkingMemory.create_factset field "+ field + " fielddomain name "+fieldDomain.getName()+" return_type_name: "+return_type_name+".");
@@ -118,7 +119,7 @@
fieldDomain.setReadingSeq(spec.readingSeq());
fieldDomain.setDiscrete(spec.discrete());
}
- domainset.put(f_name, fieldDomain);
+
newfs.addDomain(f_name, fieldDomain);
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -4,7 +4,6 @@
import dt.DecisionTree;
import dt.builder.C45TreeBuilder;
-import dt.builder.DecisionTreeBuilder;
import dt.builder.IDTreeBuilder;
import dt.memory.FactSetFactory;
import dt.memory.WorkingMemory;
@@ -14,10 +13,10 @@
try {
List<Object> obj_read=FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator);
- DecisionTreeBuilder bocuk = new IDTreeBuilder();
+ IDTreeBuilder bocuk = new IDTreeBuilder();
long dt = System.currentTimeMillis();
- String target_attr = Util.getTargetAnnotation(emptyObject.getClass());
+ String target_attr = ObjectReader.getTargetAnnotation(emptyObject.getClass());
DecisionTree bocuksTree = bocuk.build(simple, emptyObject.getClass().getName(), target_attr, null);
dt = System.currentTimeMillis() - dt;
@@ -44,7 +43,7 @@
C45TreeBuilder bocuk = new C45TreeBuilder();
long dt = System.currentTimeMillis();
- String target_attr = Util.getTargetAnnotation(emptyObject.getClass());
+ String target_attr = ObjectReader.getTargetAnnotation(emptyObject.getClass());
DecisionTree bocuksTree = bocuk.build(simple, emptyObject.getClass().getName(), target_attr, null);
dt = System.currentTimeMillis() - dt;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/ObjectReader.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/ObjectReader.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/ObjectReader.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -161,6 +161,30 @@
return element;
}
+
+ public static String getTargetAnnotation(Class<? extends Object> classObj) {
+
+ Field [] element_fields = classObj.getDeclaredFields();
+ for( Field f: element_fields) {
+ String f_name = f.getName();
+ Class<?>[] f_class = {f.getType()};
+ if (Util.isSimpleType(f_class)) {
+ Annotation[] annotations = f.getAnnotations();
+
+ // iterate over the annotations to locate the DomainSpec annotation if it exists
+ DomainSpec spec = null;
+ for (Annotation a : annotations) {
+ if (a instanceof DomainSpec) {
+ spec = (DomainSpec)a; // here it is !!!
+ if (spec.target())
+ return f_name;
+ }
+ }
+ }
+ }
+ return null;
+ }
+
//read(Class<?> element_class, Collection<Domain<?>> collection, String data, String separator)
public static Object read_(Class<?> element_class, Collection<Domain<?>> domains, String data, String separator) {
@@ -284,7 +308,7 @@
//level++;
// Get a handle to the class of the object.
- cl = (classobj instanceof Class) ? (Class) classobj : classobj
+ cl = (classobj instanceof Class) ? (Class<?>) classobj : classobj
.getClass();
// detect when we've reached out limits. This is particularly
@@ -299,7 +323,7 @@
// process each field in turn.
fields = cl.getDeclaredFields();
for (int i = 0; i < fields.length; i++) {
- Class ctype = fields[i].getType();
+ Class<?> ctype = fields[i].getType();
int mod;
String typeName = null, varName = null;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -1,12 +1,6 @@
package dt.tools;
-import java.lang.annotation.Annotation;
-import java.lang.reflect.Field;
-import java.util.Hashtable;
-import java.util.List;
-import dt.memory.DomainSpec;
-
public class Util {
public static String ntimes(String s,int n){
@@ -76,31 +70,8 @@
return 2;
}
- public static String getTargetAnnotation(Class<? extends Object> classObj) {
-
- Field [] element_fields = classObj.getDeclaredFields();
- for( Field f: element_fields) {
- String f_name = f.getName();
- Class<?>[] f_class = {f.getType()};
- if (Util.isSimpleType(f_class)) {
- Annotation[] annotations = f.getAnnotations();
-
- // iterate over the annotations to locate the MaxLength constraint if it exists
- DomainSpec spec = null;
- for (Annotation a : annotations) {
- if (a instanceof DomainSpec) {
- spec = (DomainSpec)a; // here it is !!!
- if (spec.target())
- return f_name;
- }
- }
- }
- }
- return null;
- }
-
- public static String getSum() {
+ public static String sum() {
return "sum";
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -1,6 +1,5 @@
package test;
-import java.util.List;
import dt.DecisionTree;
import dt.builder.IDTreeBuilder;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -46,6 +46,6 @@
System.out.println("Time"+dt+"\n"+bocuksTree);
RulePrinter my_printer = new RulePrinter();
- my_printer.printer(bocuksTree,"id3" , new String("src/id3/rules"+".drl"));
+ my_printer.printer(bocuksTree,"test" , new String("../dt_learning/src/test/rules"+".drl"));
}
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/test/RestaurantOld.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/test/RestaurantOld.java 2008-03-31 02:41:23 UTC (rev 19328)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/test/RestaurantOld.java 2008-03-31 04:09:10 UTC (rev 19329)
@@ -1,6 +1,5 @@
package test;
-
public class RestaurantOld {
More information about the jboss-svn-commits
mailing list