[jboss-svn-commits] JBL Code SVN: r19596 - in labs/jbossrules/contrib/machinelearning/decisiontree/src: dt/builder and 3 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Wed Apr 16 14:16:24 EDT 2008
Author: gizil
Date: 2008-04-16 14:16:24 -0400 (Wed, 16 Apr 2008)
New Revision: 19596
Modified:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Discretizer.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/DomainFactory.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSetFactory.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/ObjectReader.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/test/RestaurantOld.java
Log:
testing the decision tree performance + a bug fix in the discretizer
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -5,6 +5,8 @@
import java.util.List;
import dt.memory.Domain;
+import dt.memory.Fact;
+import dt.tools.Util;
public class DecisionTree {
@@ -24,12 +26,20 @@
/* all attributes that can be used during classification */
private ArrayList<String> attrsToClassify;
+ public DecisionTree() {
+
+ this.domainSet = new Hashtable<String, Domain<?>>();
+ this.attrsToClassify = new ArrayList<String>();
+ }
public DecisionTree(String klass) {
this.className = klass;
this.domainSet = new Hashtable<String, Domain<?>>();
this.attrsToClassify = new ArrayList<String>();
}
+ public void setClassName(String klass) {
+ this.className = klass;
+ }
public void setTarget(String targetField) {
target = targetField;
@@ -77,10 +87,15 @@
public long getNumRead() {
return FACTS_READ;
}
+
+ public Integer test(Fact f) {
+ return this.getRoot().evaluate(f);
+ }
@Override
public String toString() {
- return "Facts scanned " + FACTS_READ + "\n" + root.toString();
+ String out = "Facts scanned " + FACTS_READ + "\n";
+ return out + root.toString();
}
/*
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -1,6 +1,9 @@
package dt;
+import java.util.Comparator;
+
import dt.memory.Domain;
+import dt.memory.Fact;
import dt.tools.Util;
@@ -36,6 +39,31 @@
this.rank = rank;
}
+ public Integer evaluate(Fact f) {
+
+ Domain<?> target_domain = this.getDomain();
+ Fact pseudo_f = new Fact();
+ try {
+ pseudo_f.add(target_domain, this.getValue());
+ Comparator<Fact> targetComp = target_domain.factComparator();
+ if (targetComp.compare(f, pseudo_f) == 0 ) {
+ return Integer.valueOf(1);
+ } else {
+ return Integer.valueOf(0);
+ }
+ } catch (Exception e) {
+
+ System.out.println(Util.ntimes("\n", 10)+"Unknown situation at leafnode: " + this.getValue() + " @ "+ target_domain);
+ e.printStackTrace();
+ // Unknown
+ System.exit(0);
+ return Integer.valueOf(2);
+ }
+
+
+
+ }
+
public String toString(){
return "DECISION -> " + targetValue.toString();
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -1,7 +1,9 @@
package dt;
+import java.util.Collection;
import java.util.Hashtable;
import dt.memory.Domain;
+import dt.memory.Fact;
import dt.tools.Util;
@@ -30,14 +32,74 @@
this.domain = domain;
}
- public Hashtable<Object, TreeNode> getChildren() {
- return children;
+ public Collection<Object> getChildrenKeys() {
+ return children.keySet();
}
+ public boolean containChildKey(Object attr_key) {
+ return children.keySet().contains(attr_key);
+ }
+
+ public TreeNode getChild(Object attr_key) {
+ return children.get(attr_key);
+ }
public void setChildren(Hashtable<Object, TreeNode> children) {
this.children = children;
}
+ public Integer evaluate(Fact f) {
+
+ Domain node_domain = this.getDomain();
+ Object attr_value = f.getFieldValue(node_domain.getName());
+ //
+ try {
+ if (node_domain.isPossible(attr_value)) {
+
+ TreeNode my_node = this.getChild(node_domain.getClass(attr_value));
+
+ if (Util.DEBUG_TEST) {
+ String out = "\nDomain:"+node_domain.getName()+"->";
+ for (Object value: node_domain.getValues()) {
+ out += value+"-";
+ }
+ out = Util.ntimes("$", 5) + out + " SEARCHING for = "+ attr_value + " in "+ node_domain.getName();
+
+ out += "\n KEYS:";
+ for (Object key: this.getChildrenKeys()) {
+
+ out += " "+key +"%"+this.getChild(key).getDomain() + " :";
+ }
+ System.out.print(out);
+ System.out.print(" @myclass:"+node_domain.getClass(attr_value));
+
+ if (my_node instanceof LeafNode)
+ System.out.print(" --> leaf node");
+ else
+ System.out.print(" --> not a leaf node");
+
+
+ }
+ Integer x = my_node.evaluate(f);
+ if (Util.DEBUG_TEST) {
+ System.out.println(" <> TEST RESULT: "+ x);
+ }
+ return my_node.evaluate(f);
+ //return this.getChild(node_domain.getClass(attr_value)).evaluate(f);
+ } else {
+// throw new RuntimeException("no child exists for attribute value "
+// + attr_value);
+ // Unknown situation
+ System.out.println(Util.ntimes("\n", 1)+"Notpossible situation at treenode: " + attr_value + " @ "+ node_domain);
+ return (Integer.valueOf(2));
+ }
+ } catch (Exception e) {
+ System.out.println(Util.ntimes("\n", 1)+"Exception situation at treenode: " + attr_value + " @ "+ node_domain);
+ e.printStackTrace();
+ System.exit(0);
+ return (Integer.valueOf(2));
+ }
+ }
+
public String toString() {
return toString(1, new StringBuffer());
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -34,15 +34,18 @@
@Override
public void run() {
- result = builder.c45(dt, facts, attributeNames);
+ result = builder.train(dt, facts, attributeNames);
currentNode.addNode(value, result);
}
}
MyThread helper;
private int FUNC_CALL = 0;
- protected int num_fact_processed = 0;
+ protected int num_fact_trained = 0;
private ArrayList<Fact> unclassified_facts;
+ private ArrayList<Fact> training_facts;
+ private WorkingMemory global_wm;
+ private List<Domain<?>> domains;
/*
* treebuilder.execute(workingmemory, classtoexecute, attributestoprocess)
@@ -52,83 +55,133 @@
*
* internalprocess(attributestoprocess)
*/
-
- public int getNum_fact_processed() {
- return num_fact_processed;
+ public C45TreeBuilder(WorkingMemory wm) {
+
+ unclassified_facts = new ArrayList<Fact>();
+ training_facts = new ArrayList<Fact>();
+ global_wm = wm;
+ domains = new ArrayList<Domain<?>>();
+
}
-
- public void setNum_fact_processed(int num_fact_processed) {
- this.num_fact_processed = num_fact_processed;
- }
-
- public DecisionTree build(WorkingMemory wm, Class<?> klass,
- String targetField, List<String> workingAttributes) {
+
+ public C45TreeBuilder() {
unclassified_facts = new ArrayList<Fact>();
- DecisionTree dt = new DecisionTree(klass.getName());
- // **OPT List<FactSet> facts = new ArrayList<FactSet>();
- ArrayList<Fact> facts = new ArrayList<Fact>();
+ training_facts = new ArrayList<Fact>();
+ domains = new ArrayList<Domain<?>>();
+
+ }
+
+
+ private void setKlass(Class<?> klass) {
+ Iterator<FactSet> it_fs = global_wm.getFactsets();
FactSet klass_fs = null;
- Iterator<FactSet> it_fs = wm.getFactsets();
while (it_fs.hasNext()) {
FactSet fs = it_fs.next();
if (fs instanceof OOFactSet) {
if (klass.isAssignableFrom(((OOFactSet) fs).getFactClass())) {
// **OPT facts.add(fs);
- fs.assignTo(facts); // adding all facts of fs to "facts
+ fs.assignTo(training_facts); // adding all facts of fs to "facts
}
+ } else if (klass.getName().equalsIgnoreCase(fs.getClassName())) {
+ fs.assignTo(training_facts); // adding all facts of fs to "facts"
+
+ klass_fs = fs;
+ break;
}
if (klass.getName() == fs.getClassName()) {
klass_fs = fs;
}
}
- dt.FACTS_READ += facts.size();
-
- setNum_fact_processed(facts.size());
-
+
+ for (Domain<?> d : klass_fs.getDomains())
+ domains.add(d);
+ }
+
+ private void init(DecisionTree dt, String targetField, List<String> workingAttributes) {
+ dt.setTarget(targetField);
if (workingAttributes != null)
for (String attr : workingAttributes) {
- dt.addDomain(klass_fs.getDomain(attr));
+ dt.addDomain(global_wm.getDomain(attr));
}
- else
- for (Domain<?> d : klass_fs.getDomains())
- dt.addDomain(d);
-
- dt.setTarget(targetField);
-
+ else {
+ for (Domain<?> d : domains) {
+ dt.addDomain(d);
+ }
+ }
+
+ }
+
+ /* building with a training and test */
+ public DecisionTree build(Class<?> klass, String targetField, List<String> workingAttributes) {
+ /* gets the facts the decision tree is eligible */
+ setKlass(klass);
+
+ DecisionTree dt = new DecisionTree(klass.getName());
+ init(dt, targetField, workingAttributes);
+
+
+ DecisionTree best_dt = new DecisionTree(klass.getName());
+ init(dt, targetField, workingAttributes);
+
ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
Collections.sort(attrs);
+
+
+ dt.FACTS_READ += training_facts.size();
+ /* you must set this when the training called the first time */
+ setNum_fact_trained(training_facts.size());
- TreeNode root = c45(dt, facts, attrs);
+ //while ()
+ TreeNode root = train(dt, training_facts, attrs);
dt.setRoot(root);
+
+ System.out.println(Util.ntimes("\n", 2)+Util.ntimes("$", 5)+" TESTING "+Util.ntimes("\n", 2));
+ List<Integer> evaluation = test(dt, training_facts.subList(339, 340));
+ System.out.println("TESTING results: Mistakes "+ evaluation.get(0));
+ System.out.println("TESTING results: Corrects "+ evaluation.get(1));
+ System.out.println("TESTING results: Unknown "+ evaluation.get(2));
+ if (evaluation.get(1) == training_facts.size()) {
+ best_dt.setRoot(root);
+ }
+
return dt;
}
+
- public DecisionTree build(WorkingMemory wm, String klass,
+ public DecisionTree build(WorkingMemory wm, Class<?> klass,
String targetField, List<String> workingAttributes) {
+
unclassified_facts = new ArrayList<Fact>();
- DecisionTree dt = new DecisionTree(klass);
+ DecisionTree dt = new DecisionTree(klass.getName());
// **OPT List<FactSet> facts = new ArrayList<FactSet>();
ArrayList<Fact> facts = new ArrayList<Fact>();
FactSet klass_fs = null;
Iterator<FactSet> it_fs = wm.getFactsets();
while (it_fs.hasNext()) {
FactSet fs = it_fs.next();
- if (klass == fs.getClassName()) {
- // **OPT facts.add(fs);
+ if (fs instanceof OOFactSet) {
+ if (klass.isAssignableFrom(((OOFactSet) fs).getFactClass())) {
+ // **OPT facts.add(fs);
+ fs.assignTo(facts); // adding all facts of fs to "facts
+ }
+ } else if (klass.getName().equalsIgnoreCase(fs.getClassName())) {
fs.assignTo(facts); // adding all facts of fs to "facts"
klass_fs = fs;
break;
}
+ if (klass.getName() == fs.getClassName()) {
+ klass_fs = fs;
+ }
}
dt.FACTS_READ += facts.size();
- setNum_fact_processed(facts.size());
+ setNum_fact_trained(facts.size());
+
if (workingAttributes != null)
for (String attr : workingAttributes) {
- //System.out.println("Bok degil " + attr);
dt.addDomain(klass_fs.getDomain(attr));
}
else
@@ -140,13 +193,13 @@
ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
Collections.sort(attrs);
- TreeNode root = c45(dt, facts, attrs);
+ TreeNode root = train(dt, facts, attrs);
dt.setRoot(root);
return dt;
}
-
- public DecisionTree build_test(WorkingMemory wm, String klass,
+
+/* public DecisionTree build(WorkingMemory wm, String klass,
String targetField, List<String> workingAttributes) {
unclassified_facts = new ArrayList<Fact>();
DecisionTree dt = new DecisionTree(klass);
@@ -181,10 +234,13 @@
ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
Collections.sort(attrs);
+ TreeNode root = c45(dt, facts, attrs);
+ dt.setRoot(root);
+
return dt;
- }
+ }*/
- private TreeNode c45(DecisionTree dt, List<Fact> facts,
+ public TreeNode train(DecisionTree dt, List<Fact> facts,
List<String> attributeNames) {
FUNC_CALL++;
@@ -205,7 +261,7 @@
if (stats.getNum_supported_target_classes() == 1) {
LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
- classifiedNode.setRank((double) facts.size()/(double) getNum_fact_processed());
+ classifiedNode.setRank((double) facts.size()/(double) getNum_fact_trained());
classifiedNode.setNumSupporter(facts.size());
return classifiedNode;
@@ -216,7 +272,7 @@
/* an heuristic of the leaf classification */
Object winner = stats.getThe_winner_target_class();
LeafNode noAttributeLeftNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
- noAttributeLeftNode.setRank((double) stats.getVoteFor(winner)/ (double) num_fact_processed);
+ noAttributeLeftNode.setRank((double) stats.getVoteFor(winner)/ (double) num_fact_trained);
noAttributeLeftNode.setNumSupporter(stats.getVoteFor(winner));
/* we need to know how many guys cannot be classified and who these guys are */
@@ -257,16 +313,51 @@
majorityNode.setNumSupporter(filtered_facts.get(value).size());
currentNode.addNode(value, majorityNode);
} else {
- TreeNode newNode = c45(dt, filtered_facts.get(value), attributeNames_copy);
+ TreeNode newNode = train(dt, filtered_facts.get(value), attributeNames_copy);
currentNode.addNode(value, newNode);
}
}
return currentNode;
}
+
+
+
+ public List<Integer> test(DecisionTree dt, List<Fact> facts) {
+ /*
+ * false | true | unknown
+ * | | |
+ * 0 1 2
+ */
+ List<Integer> results = new ArrayList<Integer>(3);
+ for (int i=0; i < 3; i ++) {
+ results.add(new Integer(0));
+ }
+
+ int i = 0;
+ for (Fact f : facts) {
+ if (Util.DEBUG_TEST) {
+ System.out.println(Util.ntimes("#\n", 5)+i+ " <START> TEST: f="+ f);
+ //System.exit(0);
+ }
+ Integer result = dt.test(f);
+
+ results.set(result, Integer.valueOf(results.get(result) + 1));
+ i ++;
+ }
+ return results;
+
+ }
+
public int getNumCall() {
return FUNC_CALL;
}
+ public int getNum_fact_trained() {
+ return num_fact_trained;
+ }
+ public void setNum_fact_trained(int num_fact_processed) {
+ this.num_fact_trained = num_fact_processed;
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -3,15 +3,19 @@
import java.util.List;
import dt.DecisionTree;
+import dt.TreeNode;
+import dt.memory.Fact;
import dt.memory.WorkingMemory;
public interface DecisionTreeBuilder {
DecisionTree build(WorkingMemory wm, Class<?> klass, String targetField, List<String> workingAttributes);
+
+ public TreeNode train(DecisionTree dt, List<Fact> facts, List<String> attributeNames);
+ public List<Integer> test(DecisionTree dt, List<Fact> facts);
+ //DecisionTree build(WorkingMemory simple, String klass_name, String target_attr,List<String> workingAttributes);
- DecisionTree build(WorkingMemory simple, String klass_name, String target_attr,List<String> workingAttributes);
-
- int getNum_fact_processed();
- void setNum_fact_processed(int num);
+ int getNum_fact_trained();
+ void setNum_fact_trained(int num);
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Discretizer.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Discretizer.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Discretizer.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -24,6 +24,7 @@
private int maxDepth = 1;
private Domain binaryDomain;
+
Discretizer(Domain<?> _targetDomain, List<Fact> _facts, FactTargetDistribution _facts_in_class) {
this.facts = new ArrayList<Fact>(_facts.size());
facts.addAll(_facts);
@@ -61,8 +62,9 @@
SplitPoint last_point = new SplitPoint(facts.size()-1, (Number) facts.get(facts.size()-1).getFieldValue(splitAttr));
split_indices.add(last_point);
- find_a_split(0, facts.size(), getMaxDepth(), distribution, split_indices);
- Collections.sort(split_indices, Discretizer.getSplitComparator());
+ SplitPoint foundPoint = find_a_split(0, facts.size(), getMaxDepth(), distribution, split_indices);
+ if (foundPoint != null)
+ Collections.sort(split_indices, Discretizer.getSplitComparator());
List<Integer> splits = new ArrayList<Integer>(split_indices.size());
for (SplitPoint sp: split_indices) {
@@ -87,29 +89,46 @@
List<SplitPoint> split_points) {
if (facts.size() <= 1) {
- System.out.println("fact.size <=1 returning 0.0....");
+ if (Util.DEBUG) System.out.println("fact.size <=1 returning 0.0....");
return null;
}
facts_in_class.evaluateMajority();
if (facts_in_class.getNum_supported_target_classes()==1) {
- System.out.println("getNum_supported_target_classes=1 returning 0.0....");
+ if (Util.DEBUG) System.out.println("getNum_supported_target_classes=1 returning 0.0....");
return null; //?
}
if (depth == 0) {
- System.out.println("depth == 0 returning 0.0....");
+ if (Util.DEBUG) System.out.println("depth == 0 returning 0.0....");
return null;
}
String targetAttr = targetDomain.getName();
List<?> targetValues = targetDomain.getValues();
- if (Util.DEBUG) {
- System.out.println("Discretizer.find_a_split() attributeToSplit? " + splitAttr);
- for(int index =begin_index; index < end_index; index ++) {
- Fact f= facts.get(index);
- System.out.println("entropy.info_cont() SORTING: "+index+" attr "+splitAttr+ " "+ f );
+
+ if (Util.DEBUG) System.out.println("Discretizer.find_a_split() attributeToSplit? " + splitAttr);
+ int num_split_points = 0;
+
+ Fact fact_ = facts.get(begin_index);
+ Comparator<Fact> targetComp_ = fact_.getDomain(targetAttr).factComparator();
+ Comparator<Fact> attrComp_ = fact_.getDomain(splitAttr).factComparator();
+ if (Util.DEBUG) System.out.println("Discretizer.find_a_split() SORTING: "+0+" attr "+splitAttr+ " "+ fact_ );
+ for(int index =begin_index+1; index < end_index; index ++) {
+ Fact fact_2= facts.get(index);
+ //System.out.println("test != " + attrComp_.compare(fact_, fact_2) +" of "+ fact_.getFieldValue(splitAttr)+ " and "+ fact_2.getFieldValue(splitAttr));
+
+ if ( targetComp_.compare(fact_, fact_2)!=0 && attrComp_.compare(fact_, fact_2)!=0) {
+ num_split_points++;
+
+ if (Util.DEBUG) System.out.println("Discretizer.find_a_split() SORTING: "+index+" attr "+splitAttr+ " "+ fact_2 );
+ //break; //you can check if there is at least one
}
+ fact_ = fact_2;
}
+ if (num_split_points ==0) {
+ return null; //there is no possible split point
+ }
+
/* initialize the distribution */
Object key0 = Integer.valueOf(0);
Object key1 = Integer.valueOf(1);
@@ -117,9 +136,7 @@
// keys.add(key0);
// keys.add(key1);
-
-
- FactAttrDistribution facts_at_attribute = new FactAttrDistribution(binaryDomain, targetDomain);
+ FactAttrDistribution facts_at_attribute = new FactAttrDistribution(getBinaryDomain(), targetDomain);
facts_at_attribute.setTotal(facts.size());
facts_at_attribute.setTargetDistForAttr(key1, facts_in_class);
facts_at_attribute.setSumForAttr(key1, facts.size());
@@ -134,11 +151,12 @@
Fact f1 = f_ite.next();
Comparator<Fact> targetComp = f1.getDomain(targetAttr).factComparator();
+ Comparator<Fact> attrComp = f1.getDomain(splitAttr).factComparator();
+
if (Util.DEBUG) System.out.println("\nentropy.info_cont() SEARCHING: "+begin_index+" attr "+splitAttr+ " "+ f1 );
while (f_ite.hasNext() && index<end_index) {/* 2. Look for potential cut-points. */
Fact f2 = f_ite.next();
- if (Util.DEBUG) System.out.println("entropy.info_cont() SEARCHING: "+(index)+" attr "+splitAttr+ " "+ f2 );
Object targetKey = f1.getFieldValue(targetAttr);
@@ -157,7 +175,10 @@
* where the classes change from A to B or vice versa).
*/
- if ( targetComp.compare(f1, f2)!=0) {
+ if ( targetComp.compare(f1, f2)!=0 && attrComp.compare(f1, f2)!=0) {
+
+ if (Util.DEBUG) System.out.println("entropy.info_cont() SEARCHING: "+(index)+" attr "+splitAttr+ " "+ f2 );
+
// the cut point
Number cp_i = (Number) f1.getFieldValue(splitAttr);
Number cp_i_next = (Number) f2.getFieldValue(splitAttr);
@@ -210,6 +231,9 @@
return bestPoint;
}
+ private Domain<?> getBinaryDomain() {
+ return binaryDomain;
+ }
public static Comparator<SplitPoint> getSplitComparator() {
return new SplitComparator();
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -29,6 +29,9 @@
List<Integer> split_indices = null;
Domain<?> targetDomain = dt.getDomain(dt.getTarget());
for (String attr : attrs) {
+ if (attr.equalsIgnoreCase("c2"))
+ Util.DEBUG = false;
+
// if (attr.equalsIgnoreCase(targetDomain.getName()))
// continue;
System.out.println("Which attribute to try: "+ attr);
@@ -47,19 +50,24 @@
Discretizer visitor = new Discretizer(targetDomain, facts, facts_in_class);
attrDomain = dt.getDomain(attr).clone();
split_indices = visitor.findSplits(attrDomain);
- int index = 0;
- for (Integer i: split_indices) {
- System.out.print("Split indices:"+ i);
- System.out.print(" domain "+attrDomain.getValues().get(index));
- System.out.print(","+attrDomain.getIndices().get(index));
- System.out.println(" fact "+visitor.getSortedFact(i));
- index++;
+ if (Util.DEBUG) {
+ int index = 0;
+ for (Integer i: split_indices) {
+ System.out.print("Split indices:"+ i);
+ System.out.print(" domain "+attrDomain.getValues().get(index));
+ System.out.print(","+attrDomain.getIndices().get(index));
+ System.out.println(" fact "+visitor.getSortedFact(i));
+ index++;
+ }
}
- gain = dt_info - calc_info_contattr(visitor.getSortedFacts(), attrDomain, targetDomain, split_indices);
-
+ if (split_indices.size()==1) {
+ gain = 0.0;
+ } else {
+ gain = dt_info - calc_info_contattr(visitor.getSortedFacts(), attrDomain, targetDomain, split_indices);
+ }
}
- System.out.println(Util.ntimes("\n",3)+Util.ntimes("?",10)+" ATTR TRIAL "+attr + " the gain "+gain + " info "+ dt_info );
+ if (Util.DEBUG) System.out.println("\nATTR TRIAL "+Util.ntimes("?",10)+attr + " the gain "+gain + " info "+ dt_info );
if (gain > greatestGain) {
greatestGain = gain;
@@ -70,6 +78,9 @@
System.out.println(Util.ntimes("\n",3)+Util.ntimes("!",10)+" NEW BEST "+attributeWithGreatestGain + " the gain "+greatestGain );
}
+
+ if (attr.equalsIgnoreCase("c2"))
+ Util.DEBUG = false;
}
return bestDomain;
@@ -104,10 +115,18 @@
List<?> targetValues = targetDomain.getValues();
if (Util.DEBUG) {
System.out.println("entropy.info_cont() attributeToSplit? " + splitAttr);
- int f_i=0;
- for(Fact f: facts) {
- System.out.println("entropy.info_cont() SORTING: "+f_i+" attr "+splitAttr+ " "+ f );
- f_i++;
+
+ Fact fact_ = facts.get(begin_index);
+ Comparator<Fact> targetComp_ = fact_.getDomain(targetAttr).factComparator();
+ Comparator<Fact> attrComp_ = fact_.getDomain(splitAttr).factComparator();
+ System.out.println("entropy.info_cont() SORTING: "+0+" attr "+splitAttr+ " "+ fact_ );
+ for(int index =begin_index+1; index < end_index; index ++) {
+ Fact fact_2= facts.get(index);
+ //System.out.println("test != " + attrComp_.compare(fact_, fact_2) +" of "+ fact_.getFieldValue(splitAttr)+ " and "+ fact_2.getFieldValue(splitAttr));
+ if ( targetComp_.compare(fact_, fact_2)!=0 && attrComp_.compare(fact_, fact_2)!=0) {
+ System.out.println("entropy.info_cont() SORTING: "+index+" attr "+splitAttr+ " "+ fact_2 );
+ }
+ fact_ = fact_2;
}
}
/* initialize the distribution */
@@ -267,12 +286,13 @@
Domain<?> targetDomain = dt.getDomain(target);
for (String attr : attrs) {
double gain = 0;
- if (!dt.getDomain(attr).isDiscrete()) {
- System.err.println("Ignoring the attribute:" +attr+ " the id3 can not classify continuous attributes");
- continue;
- } else {
- gain = dt_info - info_attr(facts, dt.getDomain(attr), targetDomain);
- }
+// if (!dt.getDomain(attr).isDiscrete()) {
+// System.err.println("Ignoring the attribute:" +attr+ " the id3 can not classify continuous attributes");
+// continue;
+// } else {
+// gain = dt_info - info_attr(facts, dt.getDomain(attr), targetDomain);
+// }
+ gain = dt_info - info_attr(facts, dt.getDomain(attr), targetDomain);
if (Util.DEBUG) System.out.println("Attribute: " + attr + " the gain: " + gain);
if (gain > greatestGain) {
greatestGain = gain;
@@ -320,11 +340,12 @@
List<?> splitValues = splitDomain.getValues();
String targetAttr = targetDomain.getName();
+ if (Util.DEBUG) {
+ System.out.println("Numof classes in domain "+ splitDomain.getValues().size());
+ System.out.println("Numof splits in domain "+ splitDomain.getIndices().size());
- System.out.println("Numof classes in domain "+ splitDomain.getValues().size());
- System.out.println("Numof splits in domain "+ splitDomain.getIndices().size());
-
- System.out.println("Numof splits in indices "+ split_indices.size());
+ System.out.println("Numof splits in indices "+ split_indices.size());
+ }
FactAttrDistribution facts_at_attribute = new FactAttrDistribution(splitDomain, targetDomain);
facts_at_attribute.setTotal(facts.size());
@@ -362,14 +383,14 @@
//double sum_attr = 0.0;
if (total_num_attr > 0) {
double prob = (double) total_num_attr / (double) fact_size;
- System.out.print("{("+total_num_attr +"/"+fact_size +":"+prob +")* [");
+ if (Util.DEBUG) System.out.print("{("+total_num_attr +"/"+fact_size +":"+prob +")* [");
double info = calc_info(facts_of_attribute.getAttrFor(attr));
sum += prob * info;
- System.out.print("]} ");
+ if (Util.DEBUG) System.out.print("]} ");
}
}
- System.out.println("\n == "+sum);
+ if (Util.DEBUG) System.out.println("\n == "+sum);
return sum;
}
@@ -389,17 +410,15 @@
String out =" ";
for (Object key : targetValues) {
int num_in_class = facts_in_class.getVoteFor(key);
- // System.out.println("num_in_class : "+ num_in_class + " key "+ key+ " and the total num "+ total_num_facts);
-
+
if (num_in_class > 0) {
prob = (double) num_in_class / (double) total_num_facts;
/* TODO what if it is a sooo small number ???? */
- out += "("+num_in_class+ "/"+total_num_facts+":"+prob+")" +"*"+ Util.log2(prob) + " + ";
+ if (Util.DEBUG) out += "("+num_in_class+ "/"+total_num_facts+":"+prob+")" +"*"+ Util.log2(prob) + " + ";
sum -= prob * Util.log2(prob);
- // System.out.println("prob "+ prob +" and the plog(p)"+plog2p+"where the sum: "+sum);
}
}
- System.out.print(out +"= " +sum);
+ if (Util.DEBUG) System.out.print(out +"= " +sum);
return sum;
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/IDTreeBuilder.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -34,14 +34,14 @@
TreeNode result = null;
@Override
public void run() {
- result = builder.id3(dt, facts, attributeNames);
+ result = builder.train(dt, facts, attributeNames);
currentNode.addNode(value, result);
}
}
MyThread helper;
private int FUNC_CALL = 0;
- private int num_fact_processed = 0;
+ private int num_fact_trained = 0;
/*
* treebuilder.execute(workingmemory, classtoexecute, attributestoprocess)
@@ -74,7 +74,7 @@
}
dt.FACTS_READ += facts.size();
- setNum_fact_processed(facts.size());
+ setNum_fact_trained(facts.size());
if (workingAttributes != null)
for (String attr: workingAttributes) {
@@ -89,7 +89,7 @@
ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
Collections.sort(attrs);
- TreeNode root = id3(dt, facts, attrs);
+ TreeNode root = train(dt, facts, attrs);
dt.setRoot(root);
return dt;
@@ -114,7 +114,7 @@
}
}
dt.FACTS_READ += facts.size();
- setNum_fact_processed(facts.size());
+ setNum_fact_trained(facts.size());
if (workingAttributes != null)
for (String attr: workingAttributes) {
@@ -130,14 +130,14 @@
ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
Collections.sort(attrs);
- TreeNode root = id3(dt, facts, attrs);
+ TreeNode root = train(dt, facts, attrs);
dt.setRoot(root);
return dt;
}
//*OPT* private TreeNode decisionTreeLearning(List<FactSet> facts,
//*OPT* List<String> attributeNames) {
- private TreeNode id3(DecisionTree dt, List<Fact> facts, List<String> attributeNames) {
+ public TreeNode train(DecisionTree dt, List<Fact> facts, List<String> attributeNames) {
FUNC_CALL ++;
if (facts.size() == 0) {
@@ -154,7 +154,7 @@
if (stats.getNum_supported_target_classes() == 1) {
//*OPT* return new LeafNode(facts.get(0).getFact(0).getFieldValue(target));
LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
- classifiedNode.setRank((double)facts.size()/(double)num_fact_processed);
+ classifiedNode.setRank((double)facts.size()/(double)num_fact_trained);
classifiedNode.setNumSupporter(facts.size());
return classifiedNode;
}
@@ -164,7 +164,7 @@
/* an heuristic of the leaf classification*/
Object winner = stats.getThe_winner_target_class();
LeafNode noAttributeLeftNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
- noAttributeLeftNode.setRank((double)stats.getVoteFor(winner)/(double)num_fact_processed);
+ noAttributeLeftNode.setRank((double)stats.getVoteFor(winner)/(double)num_fact_trained);
noAttributeLeftNode.setNumSupporter(stats.getVoteFor(winner));
return noAttributeLeftNode;
}
@@ -202,7 +202,7 @@
majorityNode.setNumSupporter(filtered_facts.get(value).size());
currentNode.addNode(value, majorityNode);
} else {
- TreeNode newNode = id3(dt, filtered_facts.get(value), attributeNames_copy);
+ TreeNode newNode = train(dt, filtered_facts.get(value), attributeNames_copy);
currentNode.addNode(value, newNode);
}
}
@@ -215,10 +215,16 @@
}
- public int getNum_fact_processed() {
- return num_fact_processed;
+ public int getNum_fact_trained() {
+ return num_fact_trained;
}
- public void setNum_fact_processed(int num_fact_processed) {
- this.num_fact_processed = num_fact_processed;
+ public void setNum_fact_trained(int num_fact_processed) {
+ this.num_fact_trained = num_fact_processed;
}
+
+
+ public List<Integer> test(DecisionTree dt, List<Fact> facts) {
+ // TODO Auto-generated method stub
+ return null;
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -1,6 +1,7 @@
package dt.memory;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.Comparator;
import java.util.List;
@@ -46,6 +47,10 @@
return;
}
+ public Boolean getClass(Object value) {
+ return (Boolean)value;
+ }
+
public List<Boolean> getValues() {
return fValues;
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -11,18 +11,20 @@
boolean isDiscrete();
void setDiscrete(boolean disc);
- boolean contains(T value);
-
String getName();
void addValue(T value);
void addPseudoValue(T fieldValue);
+ T getClass(Object value);
+
List<T> getValues();
Object readString(String data);
String toString();
+
+ boolean contains(T value) throws Exception;
boolean isPossible(Object value) throws Exception;
void setReadingSeq(int readingSeq);
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/DomainFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/DomainFactory.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/DomainFactory.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -6,8 +6,8 @@
}
- public static NumericDomain createNumericDomain(String name) {
- return new NumericDomain(name);
+ public static NumericDomain createNumericDomain(String name, boolean discrete) {
+ return new NumericDomain(name, discrete);
}
public static LiteralDomain createLiteralDomain(String name) {
@@ -19,20 +19,21 @@
if (c.getName().equalsIgnoreCase("boolean")) {
System.out.println("Yuuuupiii boolean");
return createBooleanDomain(domainName);
- } else if (c.getName().equalsIgnoreCase("int") ||
- c.getName().equalsIgnoreCase("double") ||
- c.getName().equalsIgnoreCase("float")) {
+ } else if (c.getName().equalsIgnoreCase("int")) {
+ return createNumericDomain(domainName, true);
+ }else if (c.getName().equalsIgnoreCase("double") || c.getName().equalsIgnoreCase("float")) {
System.out.println("Yuuuupiii number");
- return createNumericDomain(domainName);
+ return createNumericDomain(domainName, false);
} else
return createComplexDomain(c,"kicimi ye simple: "+domainName);
else if (c.isAssignableFrom(String.class)) {
System.out.println("Yuuuupiii string");
return createLiteralDomain(domainName);
- } else if (c.isAssignableFrom(Integer.class) ||
- c.isAssignableFrom(Double.class) ||
- c.isAssignableFrom(Float.class)) {
- return createNumericDomain(domainName);
+ } else if (c.isAssignableFrom(Integer.class)) {
+ return createNumericDomain(domainName, true);
+ }
+ else if (c.isAssignableFrom(Double.class) || c.isAssignableFrom(Float.class)) {
+ return createNumericDomain(domainName, false);
} else if (c.isAssignableFrom(Boolean.class))
return createBooleanDomain(domainName);
else
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSetFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSetFactory.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSetFactory.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -204,23 +204,23 @@
wm.insert(line, klass, separator, domains);
}
}
-
+/*
public static boolean readObjectData(WorkingMemory simple, String filename,
String separator, Object nullObj) {
- /*
- * | class values
- *
- * unacc, acc, good, vgood
- * | attributes
- *
- * buying: vhigh, high, med, low.
- * maint: vhigh, high, med, low.
- * doors: 2, 3, 4, 5, more.
- * persons: 2, 4, more.
- * lug_boot: small, med, big.
- * safety: low, med, high.
- *
- */
+// *
+// * | class values
+// *
+// * unacc, acc, good, vgood
+// * | attributes
+// *
+// * buying: vhigh, high, med, low.
+// * maint: vhigh, high, med, low.
+// * doors: 2, 3, 4, 5, more.
+// * persons: 2, 4, more.
+// * lug_boot: small, med, big.
+// * safety: low, med, high.
+// *
+// *
// String[] attr_order = {"buying", "maint", "doors", "persons", "lug_boot", "safety"
// String filename = "../data/car/car.data.txt";
// String separator = ",";
@@ -239,11 +239,12 @@
return false;
}
+*/
- public static List<Object> fromFileAsObject(WorkingMemory wm, Class<?> klass, String filename, String separator)
+ public static List<Object> fromFileAsObject(WorkingMemory wm, Class<?> klass, String filename, String separator, boolean all_discrete)
throws IOException {
List<Object> obj_read = new ArrayList<Object>();
- OOFactSet fs = wm.getFactSet(klass);
+ OOFactSet fs = wm.getFactSet(klass, all_discrete);
Collection<Domain<?>> domains = fs.getDomains();
File file =new File(filename);
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -2,6 +2,7 @@
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.Comparator;
import java.util.List;
@@ -81,6 +82,33 @@
}
return false;
}
+
+ public String getClass(Object value) {
+ if (discrete) {
+ return (String)value;
+ } else {
+ String str_value = (String)value;
+
+
+ /*
+ * index of the search key, if it is contained in the list; otherwise, (-(insertion point) - 1).
+ * The insertion point is defined as the point at which the key would be inserted into the list:
+ * the index of the first element greater than the key, or list.size(), if all elements in the
+ * list are less than the specified key. Note that this guarantees that the return value will be >= 0
+ * if and only if the key is found.
+ */
+ /*
+ int insertion_point = Collections.binarySearch(fValues, str_value, sComparator);
+ if (insertion_point >= 0) {
+ return fValues.get(insertion_point);
+ } else {
+ return fValues.get(-(insertion_point));
+ }
+ */
+ return str_value;
+ }
+
+ }
public List<String> getValues() {
return fValues;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -25,6 +25,15 @@
nComparator = new NumberComparator();
readingSeq = -1;
}
+
+ public NumericDomain(String _name, boolean discrete) {
+ fName = _name.trim();
+ fValues = new ArrayList<Number>();
+ this.discrete = discrete;
+ fComparator = new FactNumericAttributeComparator(_name);
+ nComparator = new NumberComparator();
+ readingSeq = -1;
+ }
public Domain<Number> clone() {
NumericDomain dom = new NumericDomain(fName);
@@ -68,17 +77,51 @@
if (!fValues.contains(value))
fValues.add(value);
+ /* you can add to the correct position = using binary search */
Collections.sort(fValues, nComparator);
}
}
+
+ public Number getClass(Object value) {
+ if (discrete) {
+ return (Number)value;
+ } else {
+ Number num_value = (Number)value;
+
+ int insertion_point = Collections.binarySearch(fValues, num_value, nComparator);
+ /*
+ * index of the search key, if it is contained in the list; otherwise, (-(insertion point) - 1).
+ * The insertion point is defined as the point at which the key would be inserted into the list:
+ * the index of the first element greater than the key, or list.size(), if all elements in the
+ * list are less than the specified key. Note that this guarantees that the return value will be >= 0
+ * if and only if the key is found.
+ */
+ if (insertion_point >= 0) {
+ return fValues.get(insertion_point);
+ } else {
+ return fValues.get(-(insertion_point) -1);
+ }
+ }
+
+ }
- public boolean contains(Number value) {
- for(Number n: fValues) {
- if (value.intValue() == n.intValue() ||
- value.doubleValue() == n.doubleValue() ||
- value.floatValue() == n.floatValue())
- return true;
+ public boolean contains(Number value) throws Exception {
+ if (discrete) {
+ for(Number n: fValues) {
+ if (nComparator.compare(n, value) == 0)
+ return true;
+ }
+ } else {
+ if (fValues.isEmpty() || fValues.size()==1)
+ throw new Exception("Numerical domain "+fName+" is constant and not discrete but bounds are not set: possible values size: "+ fValues.size());
+
+ // they must be sorted
+ return (nComparator.compare((Number)value, fValues.get(0)) >= 0 && nComparator.compare((Number)value, fValues.get(fValues.size()-1)) <= 0);
+
+ /* should i check if the value is in one of the intervals
+ * this is necessary only if the intervals are unbroken
+ */
}
return false;
}
@@ -126,32 +169,11 @@
return false;
//System.exit(0);
if (constant) {
- //System.out.println("NumericDomain.isPossible() constant "+ value+ " ?");
- //System.exit(0);
-
- if (discrete) {
- if (fValues.contains(value))
- return true;
-
- //System.out.println("NumericDomain.isPossible() constant && discrete "+ value+ " ?");
- //System.exit(0);
- } else {
- if (fValues.isEmpty() || fValues.size()==1)
- throw new Exception("Numerical domain "+fName+" is constant and not discrete but bounds are not set: possible values size: "+ fValues.size());
- if (((Number)value).doubleValue() >= fValues.get(0).doubleValue() &&
- ((Number)value).doubleValue() <= fValues.get(1).doubleValue()) {
- return true;
- }
- //System.out.println("NumericDomain.isPossible() "+ value+ " ?");
- }
+ return this.contains((Number)value);
} else {
return true;
}
- //System.out.println("NumericDomain.isPossible() end "+ value+ " ?");
- //System.exit(0);
-
- return false;
}
public void setReadingSeq(int readingSeq) {
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -20,14 +20,14 @@
domainset = new Hashtable<String, Domain<?>>();
}
- public OOFactSet getFactSet(Class<?> klass) {
+ public OOFactSet getFactSet(Class<?> klass, boolean all_discrete) {
String element_class = klass.getName();
//System.out.println("Get the keys:"+ factsets.keys());
//System.out.println("WorkingMemory.get class "+ element_class + " exist? "+ factsets.containsKey(element_class));
OOFactSet fs;
if (!factsets.containsKey(element_class))
- fs = create_factset(klass);
+ fs = create_factset(klass, all_discrete);
else
fs = (OOFactSet) factsets.get(element_class);//TODO should i cast
@@ -35,14 +35,14 @@
return fs;
}
- public void insert(Object element) {
+ public void insert(Object element, boolean only_discrete) {
String element_class = element.getClass().getName();
//System.out.println("Get the keys:"+ factsets.keys());
//System.out.println("WorkingMemory.get class "+ element_class + " exist? "+ factsets.containsKey(element_class));
OOFactSet fs;
if (!factsets.containsKey(element_class))
- fs = create_factset(element.getClass());
+ fs = create_factset(element.getClass(), only_discrete);
else
fs = (OOFactSet) factsets.get(element_class);//TODO should i cast
@@ -84,7 +84,7 @@
* we said that the domains should be independent from the factset
*/
- private OOFactSet create_factset(Class<?> classObj) {
+ private OOFactSet create_factset(Class<?> classObj, boolean all_discrete) {
//System.out.println("WorkingMemory.create_factset element "+ element );
OOFactSet newfs = new OOFactSet(classObj);
@@ -96,8 +96,14 @@
System.out.println("WHat is this f: " +f.getType()+" the name "+f_name+" class "+ f.getClass() + " and the name"+ f.getClass().getName());
if (Util.isSimpleType(f_class)) {
- Annotation[] annotations = f.getAnnotations();
+ Domain<?> fieldDomain;
+ if (!domainset.containsKey(f_name)) {
+ fieldDomain = DomainFactory.createDomainFromClass(f.getType(), f_name);
+ domainset.put(f_name, fieldDomain);
+ } else
+ fieldDomain = domainset.get(f_name);
+ Annotation[] annotations = f.getAnnotations();
// iterate over the annotations to locate the MaxLength constraint if it exists
DomainSpec spec = null;
for (Annotation a : annotations) {
@@ -107,18 +113,19 @@
}
}
- Domain<?> fieldDomain;
- if (!domainset.containsKey(f_name)) {
- fieldDomain = DomainFactory.createDomainFromClass(f.getType(), f_name);
- domainset.put(f_name, fieldDomain);
- } else
- fieldDomain = domainset.get(f_name);
-
- //System.out.println("WorkingMemory.create_factset field "+ field + " fielddomain name "+fieldDomain.getName()+" return_type_name: "+return_type_name+".");
if (spec != null) {
fieldDomain.setReadingSeq(spec.readingSeq());
- fieldDomain.setDiscrete(spec.discrete());
+ if (!all_discrete) fieldDomain.setDiscrete(spec.discrete());
}
+ /*
+ * ID3 would
+ * if it is integer and the annotation saying that the field is continuous
+ * ignore the domain
+ * if it is double / float and the annotation saying that the field is continuous
+ * ignore the domain if it has more than 10 values ?
+ * if it is string and the annotation saying that the field is continuous
+ * what to do??
+ */
newfs.addDomain(f_name, fieldDomain);
@@ -127,7 +134,7 @@
factsets.put(classObj.getName(), newfs);
return newfs;
}
-
+
private OOFactSet create_factset_(Class<?> classObj) {
//System.out.println("WorkingMemory.create_factset element "+ element );
@@ -170,6 +177,10 @@
return newfs;
}
+// public Iterator<Domain<?>> getDomains() {
+// return domainset.values().iterator();
+// }
+
/* TODO: is there a better way of doing this iterator? */
public Iterator<FactSet> getFactsets() {
return factsets.values().iterator();
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -12,7 +12,8 @@
public static List<Object> processFileExmID3(WorkingMemory simple, Object emptyObject, String drlfile, String datafile, String separator) {
try {
- List<Object> obj_read=FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator);
+ boolean only_discrete = true;
+ List<Object> obj_read=FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator, only_discrete);
IDTreeBuilder bocuk = new IDTreeBuilder();
long dt = System.currentTimeMillis();
@@ -22,9 +23,11 @@
dt = System.currentTimeMillis() - dt;
System.out.println("Time" + dt + "\n" + bocuksTree);
- RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed());
+ RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_trained());
boolean sort_via_rank = true;
- my_printer.printer(bocuksTree, "examples", "src/rules/examples/"+drlfile, sort_via_rank);
+ boolean print = true;
+ my_printer.printer(bocuksTree, sort_via_rank, print);
+ my_printer.write2file("examples", "src/rules/examples/" + drlfile);
return obj_read;
@@ -36,11 +39,13 @@
}
+
- public static List<Object> processFileExmC45(WorkingMemory simple, Object emptyObject, String drlfile, String datafile, String separator) {
+ public static List<Object> processFileExmC45(WorkingMemory simple, Object emptyObject, String drlfile, String datafile, String separator, int max_rules) {
try {
- List<Object> obj_read=FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator);
+ boolean all_discrete = false;
+ List<Object> obj_read=FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator, all_discrete);
C45TreeBuilder bocuk = new C45TreeBuilder();
long dt = System.currentTimeMillis();
@@ -48,13 +53,17 @@
List<String> workingAttributes= ObjectReader.getWorkingAttributes(emptyObject.getClass());
- DecisionTree bocuksTree = bocuk.build(simple, emptyObject.getClass().getName(), target_attr, workingAttributes);
+ DecisionTree bocuksTree = bocuk.build(simple, emptyObject.getClass(), target_attr, workingAttributes);
dt = System.currentTimeMillis() - dt;
System.out.println("Time" + dt + "\n" + bocuksTree);
- RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed());
+ RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_trained());
+ if (max_rules >0)
+ my_printer.setMax_num_rules(max_rules);
boolean sort_via_rank = true;
- my_printer.printer(bocuksTree, "examples", "src/rules/examples/"+drlfile, sort_via_rank);
+ boolean print = true;
+ my_printer.printer(bocuksTree, sort_via_rank, print);
+ my_printer.write2file("examples", "src/rules/examples/" + drlfile);
return obj_read;
@@ -67,25 +76,31 @@
}
- public static List<Object> processFileExmC45(WorkingMemory simple, Object emptyObject, String drlfile, String datafile, String separator, int max_rules) {
+ public static List<Object> test_dt_training(WorkingMemory simple, Object emptyObject, String drlfile, String datafile, String separator, int max_rules) {
try {
- List<Object> obj_read=FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator);
- C45TreeBuilder bocuk = new C45TreeBuilder();
+ boolean all_discrete = false;
+ List<Object> obj_read=FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator, all_discrete);
+
long dt = System.currentTimeMillis();
String target_attr = ObjectReader.getTargetAnnotation(emptyObject.getClass());
List<String> workingAttributes= ObjectReader.getWorkingAttributes(emptyObject.getClass());
- DecisionTree bocuksTree = bocuk.build(simple, emptyObject.getClass().getName(), target_attr, workingAttributes);
+ C45TreeBuilder bocuk = new C45TreeBuilder(simple);
+ DecisionTree bocuksTree = bocuk.build(emptyObject.getClass(), target_attr, workingAttributes);
dt = System.currentTimeMillis() - dt;
- System.out.println("Time" + dt + "\n" + bocuksTree);
-
- RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed(), max_rules);
- boolean sort_via_rank = true;
- my_printer.printer(bocuksTree, "examples", "src/rules/examples/"+drlfile, sort_via_rank);
-
+// System.out.println("Time" + dt + "\n" + bocuksTree);
+//
+// RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_trained());
+// if (max_rules >0)
+// my_printer.setMax_num_rules(max_rules);
+// boolean sort_via_rank = true;
+// boolean print = true;
+// my_printer.printer(bocuksTree, sort_via_rank, print);
+// my_printer.write2file("examples", "src/rules/examples/" + drlfile);
+//
return obj_read;
} catch (Exception e) {
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/ObjectReader.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/ObjectReader.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/ObjectReader.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -22,6 +22,7 @@
private static final boolean DEBUG = false;
+
public static Object read(Class<?> element_class,
Collection<Domain<?>> domains, String data, String separator) {
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/RulePrinter.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -26,7 +26,7 @@
private boolean ONLY_ACTIVE = true;
private int num_facts;
//private RuleComparator rule_comp = new RuleComparator();
- private int max_num_rules;
+ private int max_num_rules = -1;
public RulePrinter(int num_facts) {
@@ -40,18 +40,6 @@
this.num_facts = num_facts;
}
-
- public RulePrinter(int num_facts, int max_num_rules) {
- ruleText = new ArrayList<String>();
- //rule_list = new ArrayList<ArrayList<NodeValue>>();
- rules = new ArrayList<Rule>();
-
- /* most important */
- nodes = new Stack<NodeValue>();
-
- this.num_facts = num_facts;
- this.max_num_rules = max_num_rules;
- }
public int getNum_facts() {
return num_facts;
}
@@ -60,10 +48,35 @@
this.num_facts = num_facts;
}
- public void printer(DecisionTree dt, String packageName, String outputFile, boolean sort) {//, PrintStream object
+ public int getMax_num_rules() {
+ return max_num_rules;
+ }
+
+ public void setMax_num_rules(int max_num_rules) {
+ this.max_num_rules = max_num_rules;
+ }
+
+ public void printer(DecisionTree dt, boolean sort, boolean print) {//, PrintStream object
ruleObject = dt.getName();
dfs(dt.getRoot());
+
+ if (sort)
+ Collections.sort(rules, Rule.getRankComparator());
+ if (print)
+ System.out.println(printRules());
+ }
+ public String printRules() {
+ String out = "";
+ int i = 0;
+ for( Rule rule: rules) {
+ i++;
+ out += ("Rule " +i + " rank("+rule.getRank()+")"+" suggests that \n"+ rule.toPrint() +".\n");
+ }
+ return out;
+ }
+
+ public void write2file(String packageName, String outputFile) {
if (outputFile!=null) {
if (packageName != null)
write("package " + packageName +";\n\n", false, outputFile);
@@ -76,9 +89,7 @@
}
}
- if (sort)
- Collections.sort(rules, Rule.getRankComparator());
-
+ System.out.println("//Num of rules " +rules.size()+"\n");
int total_num_facts=0;
int i = 0, active_i = 0;
for( Rule rule: rules) {
@@ -104,7 +115,7 @@
}
}
total_num_facts += rule.getPopularity();
- if (i == getMax_num_rules())
+ if (getMax_num_rules()>0 && i >= getMax_num_rules())
break;
}
if (outputFile!=null) {
@@ -135,19 +146,15 @@
return;
}
- Hashtable<Object,TreeNode> children = my_node.getChildren();
- for (Object attributeValue : children.keySet()) {
+ //Hashtable<Object,TreeNode> children = my_node.getChildrenKeys();
+ for (Object attributeValue : my_node.getChildrenKeys()) {
//System.out.println("Domain: "+ my_node.getDomain().getName() + " the value:"+ attributeValue);
node_value.setNodeValue(attributeValue);
- TreeNode child = children.get(attributeValue);
+ TreeNode child = my_node.getChild(attributeValue);
dfs(child);
nodes.pop();
}
return;
-
-
-
-
}
private Rule spitRule(Stack<NodeValue> nodes) {
@@ -232,16 +239,6 @@
}
}
}
-
-
- public int getMax_num_rules() {
- return max_num_rules;
- }
-
-
- public void setMax_num_rules(int max_num_rules) {
- this.max_num_rules = max_num_rules;
- }
}
class Rule {
@@ -256,7 +253,9 @@
conditions = new ArrayList<NodeValue>(numCond);
actions = new ArrayList<NodeValue>(1);
}
-
+ public void setRank(double r) {
+ this.rank = r;
+ }
public double getRank() {
return rank;
}
@@ -266,7 +265,7 @@
}
public void addAction(NodeValue current) {
actions.add(new NodeValue(current.getNode(), current.getNodeValue()));
- rank = ((LeafNode)current.getNode()).getRank();
+ this.setRank(((LeafNode)current.getNode()).getRank());
popularity = ((LeafNode)current.getNode()).getNum_facts_classified();
}
public void setObject(String obj) {
@@ -294,6 +293,22 @@
this.popularity = popularity;
}
+ public String toPrint() {
+
+ String out = "if ";
+ for (NodeValue cond: conditions) {
+ out += cond + " & ";
+ }
+ out = out.substring(0, out.length()-2);
+
+ String action = "";
+ for (NodeValue act: actions) {
+ action += act.getNodeValue() + " & ";
+ }
+ action = action.substring(0, action.length()-3);
+ out += "then DECISION("+action +")";
+ return out;
+ }
public String toString() {
/*
@@ -395,10 +410,15 @@
else {
int size = node.getDomain().getValues().size();
//System.out.println("How many guys there of "+node.getDomain().getName() +" and the value "+nodeValue+" : "+size);
- if (node.getDomain().getValues().lastIndexOf(nodeValue) == size-1)
+ if (node.getDomain().getValues().lastIndexOf(nodeValue) == 0)
+ return node.getDomain() + " <= "+ value;
+ else if (node.getDomain().getValues().lastIndexOf(nodeValue) == size-1)
return node.getDomain() + " > "+ node.getDomain().getValues().get(size-2);
- else
- return node.getDomain() + " <= "+ value;
+ else {
+ //find the one before him
+ int current = node.getDomain().getValues().lastIndexOf(nodeValue);
+ return node.getDomain().getValues().get(current-1) +" < "+node.getDomain() + " <= "+ value;
+ }
}
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -7,7 +7,8 @@
public class Util {
- public static boolean DEBUG = true;
+ public static boolean DEBUG = false;
+ public static boolean DEBUG_TEST = true;
public static String ntimes(String s,int n){
StringBuffer buf = new StringBuffer();
@@ -80,7 +81,7 @@
public static String sum() {
return "sum";
}
-
+
public static void insert(List<Object> list, Object key, Comparator<Object> c) {
int insertion_point_1 = Collections.binarySearch(list, key, c);
if (insertion_point_1 <0)
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukFileExample.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -43,9 +43,11 @@
System.out.println("Time"+dt + " facts read: "+bocuksTree.getNumRead() + " num call: "+ bocuk.getNumCall() );
//System.out.println(bocuksTree);
- RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed());
+ RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_trained());
boolean sort_via_rank = true;
- my_printer.printer(bocuksTree, null, null, sort_via_rank);
+ boolean print = true;
+ my_printer.printer(bocuksTree, sort_via_rank, print);
+ my_printer.write2file(null, null);
}
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/test/BocukObjectExample.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -27,7 +27,8 @@
facts.add(new RestaurantOld(true, true, true, true, "Full", 1, false, false, "Burger", "30-60", true));
WorkingMemory simple = new WorkingMemory();
- OOFactSet fs = simple.getFactSet(arest.getClass());
+ boolean only_discrete = true;
+ OOFactSet fs = simple.getFactSet(arest.getClass(), only_discrete);
for(Object r: facts) {
try {
@@ -45,8 +46,10 @@
dt = System.currentTimeMillis() - dt;
System.out.println("Time"+dt+"\n"+bocuksTree);
- RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_processed());
+ RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_trained());
boolean sort_via_rank = true;
- my_printer.printer(bocuksTree,"test" , new String("../dt_learning/src/test/rules"+".drl"), sort_via_rank);
+ boolean print = true;
+ my_printer.printer(bocuksTree, sort_via_rank, print);
+ my_printer.write2file("test" , new String("../dt_learning/src/test/rules"+".drl"));
}
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/test/RestaurantOld.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/test/RestaurantOld.java 2008-04-16 17:53:59 UTC (rev 19595)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/test/RestaurantOld.java 2008-04-16 18:16:24 UTC (rev 19596)
@@ -141,6 +141,10 @@
this.will_wait = will_wait;
}
+ public void p() {
+ System.out.println(this.hashCode());
+ }
+
}
More information about the jboss-svn-commits mailing list