[jboss-svn-commits] JBL Code SVN: r19651 - in labs/jbossrules/contrib/machinelearning/decisiontree/src/dt: builder and 2 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Sun Apr 20 10:11:36 EDT 2008
Author: gizil
Date: 2008-04-20 10:11:35 -0400 (Sun, 20 Apr 2008)
New Revision: 19651
Added:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/DecisionTreeSerializer.java
Modified:
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactLiteralAttributeComparator.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSet.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumberComparator.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java
Log:
serializing decision tree and tree builder + the dumbest way of retraining a decision tree (= iterating over an existing tree)
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,15 +1,22 @@
package dt;
+import java.io.Serializable;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.List;
import dt.memory.Domain;
import dt.memory.Fact;
-import dt.tools.Util;
-public class DecisionTree {
+public class DecisionTree implements Serializable{
+
+
+ /**
+ *
+ */
+ private static final long serialVersionUID = 1L;
+
public long FACTS_READ = 0;
/* set of the attributes, their types */
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -12,13 +12,34 @@
private Object targetValue;
private double rank;
private int num_facts_classified;
+
+ private Fact pseudo_f;
public LeafNode(Domain<?> targetDomain, Object value){
super(targetDomain);
this.targetValue = value;
num_facts_classified = 0;
+
+ this.pseudo_f = new Fact();
+ this.setPseudoFact();
}
+ public void setTargetValue(Object value) {
+ this.targetValue = value;
+ this.pseudo_f = new Fact();
+ this.setPseudoFact();
+ }
+ public void setPseudoFact() {
+ try {
+ pseudo_f.add(this.getDomain(), this.getValue());
+ } catch (Exception e) {
+ System.out.println(Util.ntimes("\n", 10)+"Unknown situation at leafnode: " + this.getValue() + " @ "+ this.getDomain());
+ e.printStackTrace();
+ // Unknown
+ System.exit(0);
+
+ }
+ }
public void addNode(Object attributeValue, TreeNode node) {
throw new RuntimeException("cannot add Node to a leaf node");
}
@@ -42,26 +63,14 @@
public Integer evaluate(Fact f) {
Domain<?> target_domain = this.getDomain();
- Fact pseudo_f = new Fact();
- try {
- pseudo_f.add(target_domain, this.getValue());
- Comparator<Fact> targetComp = target_domain.factComparator();
- if (targetComp.compare(f, pseudo_f) == 0 ) {
- return Integer.valueOf(1);
- } else {
- return Integer.valueOf(0);
- }
- } catch (Exception e) {
-
- System.out.println(Util.ntimes("\n", 10)+"Unknown situation at leafnode: " + this.getValue() + " @ "+ target_domain);
- e.printStackTrace();
- // Unknown
- System.exit(0);
- return Integer.valueOf(2);
- }
+ Comparator<Fact> targetComp = target_domain.factComparator();
+ if (targetComp.compare(f, this.pseudo_f) == 0 ) {
+ return Integer.valueOf(1); //correct
+ } else {
+ return Integer.valueOf(0); // mistake
+ }
-
}
public String toString(){
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,4 +1,5 @@
package dt;
+import java.io.Serializable;
import java.util.Collection;
import java.util.Hashtable;
@@ -7,7 +8,7 @@
import dt.tools.Util;
-public class TreeNode {
+public class TreeNode implements Serializable{
private Domain<?> domain;
private Hashtable<Object, TreeNode> children;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,8 +1,8 @@
package dt.builder;
import java.util.ArrayList;
-import java.util.Collection;
import java.util.Collections;
+import java.util.Comparator;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
@@ -10,14 +10,12 @@
import dt.DecisionTree;
import dt.LeafNode;
import dt.TreeNode;
-
+import dt.memory.Domain;
+import dt.memory.Fact;
import dt.memory.FactDistribution;
-import dt.memory.FactTargetDistribution;
-import dt.memory.WorkingMemory;
-import dt.memory.Fact;
import dt.memory.FactSet;
import dt.memory.OOFactSet;
-import dt.memory.Domain;
+import dt.memory.WorkingMemory;
import dt.tools.FactProcessor;
import dt.tools.Util;
@@ -42,10 +40,14 @@
MyThread helper;
private int FUNC_CALL = 0;
protected int num_fact_trained = 0;
+ private ArrayList<Fact> facts;
+ private ArrayList<Fact> training_facts;
private ArrayList<Fact> unclassified_facts;
- private ArrayList<Fact> training_facts;
+
private WorkingMemory global_wm;
private List<Domain<?>> domains;
+ private String target;
+ private List<String> attributes;
/*
* treebuilder.execute(workingmemory, classtoexecute, attributestoprocess)
@@ -56,23 +58,43 @@
* internalprocess(attributestoprocess)
*/
public C45TreeBuilder(WorkingMemory wm) {
+
+ global_wm = wm;
+ facts = new ArrayList<Fact>();
+ training_facts = new ArrayList<Fact>();
unclassified_facts = new ArrayList<Fact>();
- training_facts = new ArrayList<Fact>();
- global_wm = wm;
+
+ target = null;
+ attributes = new ArrayList<String>();
domains = new ArrayList<Domain<?>>();
-
}
public C45TreeBuilder() {
+ facts = new ArrayList<Fact>();
+ training_facts = new ArrayList<Fact>();
unclassified_facts = new ArrayList<Fact>();
- training_facts = new ArrayList<Fact>();
+
+ target = null;
+ attributes = new ArrayList<String>();
domains = new ArrayList<Domain<?>>();
-
+
}
+ /* set the builder's
+ * domains
+ */
+ public void setDomains(Class<?> klass) {
+ FactSet klass_fs = null;
+
+ for (Domain<?> d : klass_fs.getDomains())
+ domains.add(d);
+ }
-
+ /* set the builder's
+ * facts
+ * domains
+ */
private void setKlass(Class<?> klass) {
Iterator<FactSet> it_fs = global_wm.getFactsets();
FactSet klass_fs = null;
@@ -81,10 +103,10 @@
if (fs instanceof OOFactSet) {
if (klass.isAssignableFrom(((OOFactSet) fs).getFactClass())) {
// **OPT facts.add(fs);
- fs.assignTo(training_facts); // adding all facts of fs to "facts
+ fs.assignTo(facts); // adding all facts of fs to "facts
}
} else if (klass.getName().equalsIgnoreCase(fs.getClassName())) {
- fs.assignTo(training_facts); // adding all facts of fs to "facts"
+ fs.assignTo(facts); // adding all facts of fs to "facts"
klass_fs = fs;
break;
@@ -98,8 +120,47 @@
domains.add(d);
}
- private void init(DecisionTree dt, String targetField, List<String> workingAttributes) {
+ /* initialize the builder's
+ * targetField
+ * the attribute list (workingAttributes != null ? workingAttributes : domains )
+ */
+ public void init(String targetField, List<String> workingAttributes) {
+ this.setTarget(targetField);
+ if (workingAttributes != null)
+ for (String attr : workingAttributes) {
+ this.addAttribute(attr);
+ }
+ else {
+ for (Domain<?> d : domains) {
+ this.addAttribute(d.getName());
+ }
+ }
+
+ }
+
+ public void setTarget(String targetField) {
+ this.target = targetField;
+ //attrsToClassify.remove(target);
+ }
+
+ public void addDomain(Domain<?> d) {
+ //if (!attribute.equals(this.target))
+ //attributes.add(d.getName());
+ domains.add(d);
+ }
+ public void addAttribute(String attribute) {
+ //if (!attribute.equals(this.target))
+ attributes.add(attribute);
+ }
+ private void init_dt(DecisionTree dt, String targetField) {
dt.setTarget(targetField);
+ for (Domain<?> d : domains) {
+ dt.addDomain(d);
+ }
+
+ }
+ private void init_dt(DecisionTree dt, String targetField, List<String> workingAttributes) {
+ dt.setTarget(targetField);
if (workingAttributes != null)
for (String attr : workingAttributes) {
dt.addDomain(global_wm.getDomain(attr));
@@ -111,23 +172,18 @@
}
}
-
- /* building with a training and test */
- public DecisionTree build(Class<?> klass, String targetField, List<String> workingAttributes) {
+ /* building with the training set (all relevant facts from wm) from scratch */
+ public DecisionTree build(Class<?> klass) {
/* gets the facts the decision tree is eligible */
setKlass(klass);
DecisionTree dt = new DecisionTree(klass.getName());
- init(dt, targetField, workingAttributes);
+ init_dt(dt, this.target, this.attributes);
-
- DecisionTree best_dt = new DecisionTree(klass.getName());
- init(dt, targetField, workingAttributes);
-
ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
Collections.sort(attrs);
-
+ training_facts.addAll(facts);
dt.FACTS_READ += training_facts.size();
/* you must set this when the training called the first time */
setNum_fact_trained(training_facts.size());
@@ -135,20 +191,29 @@
//while ()
TreeNode root = train(dt, training_facts, attrs);
dt.setRoot(root);
+
+ return dt;
+ }
+
+
+ /* building with a training and test */
+ public DecisionTree build_test(Class<?> klass, String targetField, List<String> workingAttributes) {
+ if (this.target == null) {
+ System.out.println("Target is not set");
+ System.exit(0);
+ }
+ /* gets the facts the decision tree is eligible */
+ DecisionTree tree = build(klass);
+
System.out.println(Util.ntimes("\n", 2)+Util.ntimes("$", 5)+" TESTING "+Util.ntimes("\n", 2));
- List<Integer> evaluation = test(dt, training_facts.subList(339, 340));
+ List<Integer> evaluation = test(tree, training_facts);//.subList(339, 340));
System.out.println("TESTING results: Mistakes "+ evaluation.get(0));
System.out.println("TESTING results: Corrects "+ evaluation.get(1));
- System.out.println("TESTING results: Unknown "+ evaluation.get(2));
- if (evaluation.get(1) == training_facts.size()) {
- best_dt.setRoot(root);
- }
-
- return dt;
+ System.out.println("TESTING results: Unknown "+ evaluation.get(2) +" OF "+ training_facts.size() + " facts");
+ return tree;
}
-
public DecisionTree build(WorkingMemory wm, Class<?> klass,
String targetField, List<String> workingAttributes) {
@@ -199,50 +264,32 @@
return dt;
}
-/* public DecisionTree build(WorkingMemory wm, String klass,
- String targetField, List<String> workingAttributes) {
- unclassified_facts = new ArrayList<Fact>();
- DecisionTree dt = new DecisionTree(klass);
- // **OPT List<FactSet> facts = new ArrayList<FactSet>();
- ArrayList<Fact> facts = new ArrayList<Fact>();
- FactSet klass_fs = null;
- Iterator<FactSet> it_fs = wm.getFactsets();
- while (it_fs.hasNext()) {
- FactSet fs = it_fs.next();
- if (klass == fs.getClassName()) {
- // **OPT facts.add(fs);
- fs.assignTo(facts); // adding all facts of fs to "facts"
-
- klass_fs = fs;
- break;
- }
- }
- dt.FACTS_READ += facts.size();
- setNum_fact_processed(facts.size());
-
- if (workingAttributes != null)
- for (String attr : workingAttributes) {
- //System.out.println("Bok degil " + attr);
- dt.addDomain(klass_fs.getDomain(attr));
- }
- else
- for (Domain<?> d : klass_fs.getDomains())
- dt.addDomain(d);
-
- dt.setTarget(targetField);
-
+ /* building with the training set (some part of the facts) */
+ public DecisionTree build(Class<?> klass, List<Fact> first_facts) {
+ /* gets the facts which the decision tree is eligible */
+ //setKlass(klass);
+
+ DecisionTree dt = new DecisionTree(klass.getName());
+ init_dt(dt, this.target); // initialize the decision tree with the target and all domains
+
ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
Collections.sort(attrs);
+
+ training_facts.addAll(first_facts);
+ dt.FACTS_READ += first_facts.size();
+ /* you must set this when the training called the first time */
+ setNum_fact_trained(training_facts.size());
- TreeNode root = c45(dt, facts, attrs);
+ //while ()
+ TreeNode root = train(dt, training_facts, attrs);
dt.setRoot(root);
-
+
+
return dt;
- }*/
+ }
+
+ public TreeNode train(DecisionTree dt, List<Fact> facts, List<String> attributeNames) {
- public TreeNode train(DecisionTree dt, List<Fact> facts,
- List<String> attributeNames) {
-
FUNC_CALL++;
if (facts.size() == 0) {
throw new RuntimeException("Nothing to classify, factlist is empty");
@@ -285,7 +332,7 @@
// String chosenAttribute = Entropy.chooseContAttribute(dt, facts, stats, attributeNames);
// List<?> categorization = dt.getPossibleValues(chosenAttribute);
Domain<?> choosenDomain = Entropy.chooseContAttribute(dt, facts, stats, attributeNames);
- System.out.println(Util.ntimes("*", 20) + " 1st best attr: "+ choosenDomain.getName());
+ if (Util.RUN) System.out.println(Util.ntimes("*", 20) + " 1st best attr: "+ choosenDomain.getName());
TreeNode currentNode = new TreeNode(choosenDomain);
@@ -321,8 +368,161 @@
return currentNode;
}
+ /* re-building an existing tree: adds the new facts to the training set and re-trains from the current root */
+ public DecisionTree re_build(DecisionTree dt, List<Fact> new_facts) {
+
+ ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
+ Collections.sort(attrs);
+
+ training_facts.addAll(new_facts);
+ dt.FACTS_READ += new_facts.size();
+ /* you must set this when the training called the first time */
+ setNum_fact_trained(training_facts.size());
+ System.out.println(Util.ntimes("\n", 10)+"How facts are u training? "+ training_facts.size());
+ //while ()
+ TreeNode root = re_train(dt, dt.getRoot(), training_facts, attrs);
+ dt.setRoot(root);
+
+ return dt;
+ }
+ public TreeNode re_train(DecisionTree dt, TreeNode currentNode, List<Fact> facts, List<String> attributeNames) {
+
+ FUNC_CALL++;
+ if (facts.size() == 0) {
+ throw new RuntimeException("Nothing to classify, factlist is empty");
+ }
+ /* let's get the statistics of the results */
+ // List<?> targetValues = dt.getPossibleValues(dt.getTarget());
+ //Hashtable<Object, Integer> stats_ = dt.getStatistics(facts, dt.getTarget());// targetValues
+
+ //FactTargetDistribution stats = dt.getDistribution(facts);
+
+ FactDistribution stats = new FactDistribution(dt.getDomain(dt.getTarget()));
+ stats.calculateDistribution(facts);
+ stats.evaluateMajority();
+
+ /* if all elements are classified to the same value */
+ if (stats.getNum_supported_target_classes() == 1) {
+
+ LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
+ classifiedNode.setRank((double) facts.size()/(double) getNum_fact_trained());
+ classifiedNode.setNumSupporter(facts.size());
+
+ return classifiedNode;
+ }
+
+ /* if there is no attribute left in order to continue */
+ if (attributeNames.size() == 0) {
+ /* an heuristic of the leaf classification */
+ Object winner = stats.getThe_winner_target_class();
+ LeafNode noAttributeLeftNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+ noAttributeLeftNode.setRank((double) stats.getVoteFor(winner)/ (double) num_fact_trained);
+ noAttributeLeftNode.setNumSupporter(stats.getVoteFor(winner));
+
+ /* we need to know how many guys cannot be classified and who these guys are */
+ FactProcessor.splitUnclassifiedFacts(unclassified_facts, stats);
+
+ return noAttributeLeftNode;
+ }
+
+ /* choosing the attribute for the branching starts */
+// String chosenAttribute = Entropy.chooseContAttribute(dt, facts, stats, attributeNames);
+// List<?> categorization = dt.getPossibleValues(chosenAttribute);
+ Domain<?> choosenDomain = Entropy.chooseContAttribute(dt, facts, stats, attributeNames);
+ if (Util.RUN) System.out.println(Util.ntimes("*", 20) + " 1st best attr: "+ choosenDomain.getName());
+ else if (FUNC_CALL % 100 ==0){
+ System.out.print(".");
+ }
+
+ Hashtable<Object, List<Fact>> filtered_facts = FactProcessor.splitFacts(facts, choosenDomain);
+ for (Object value : filtered_facts.keySet()) {
+ if (filtered_facts.get(value).isEmpty()){
+ @SuppressWarnings("unused")
+ boolean bok = true;
+ }
+ }
+ dt.FACTS_READ += facts.size();
+
+ if (currentNode.getDomain() == choosenDomain) {
+
+
+ for (Object value : filtered_facts.keySet()) {
+
+ TreeNode childNode = currentNode.getChild(value);
+ /* split the last two class at the same time */
+
+ ArrayList<String> attributeNames_copy = new ArrayList<String>(
+ attributeNames);
+ attributeNames_copy.remove(choosenDomain.getName());
+
+ if (filtered_facts.get(value).isEmpty()) {
+ /* majority !!!! */
+ //Comparator<Fact> targetComp = dt.getDomain(dt.getTarget()).factComparator();
+
+ if (childNode == null || !(childNode instanceof LeafNode)) {
+ LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
+ majorityNode.setRank(-1.0); // classifying nothing
+ majorityNode.setNumSupporter(filtered_facts.get(value).size());
+
+ childNode = majorityNode; // How to set this guy
+ if (childNode == null)
+ currentNode.addNode(value, childNode);
+ }
+
+ else {
+ /* have to remove the leafnode from the children list with key value*/
+ ((LeafNode)childNode).setRank(-1.0); // classifying nothing
+ ((LeafNode)childNode).setNumSupporter(filtered_facts.get(value).size());
+
+ if (dt.getDomain(dt.getTarget()).compare(((LeafNode)childNode).getValue(), value)!=0) {
+ ((LeafNode)childNode).setTargetValue(value);
+ //currentNode.
+ }
+ }
+
+
+
+ } else {
+ if (childNode == null) {
+ TreeNode newNode = train(dt, filtered_facts.get(value), attributeNames_copy);
+ currentNode.addNode(value, newNode);
+ }
+ TreeNode newNode = re_train(dt, childNode, filtered_facts.get(value), attributeNames_copy);
+ //currentNode.addNode(value, newNode);
+ }
+ }
+
+ } else {
+ currentNode = new TreeNode(choosenDomain);
+
+ for (Object value : filtered_facts.keySet()) {
+
+ /* split the last two class at the same time */
+
+ ArrayList<String> attributeNames_copy = new ArrayList<String>(
+ attributeNames);
+ attributeNames_copy.remove(choosenDomain.getName());
+
+ if (filtered_facts.get(value).isEmpty()) {
+ /* majority !!!! */
+ LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
+ majorityNode.setRank(-1.0); // classifying nothing
+ majorityNode.setNumSupporter(filtered_facts.get(value).size());
+ currentNode.addNode(value, majorityNode);
+ } else {
+
+ TreeNode newNode = train(dt, filtered_facts.get(value), attributeNames_copy);
+ currentNode.addNode(value, newNode);
+ }
+ }
+ }
+
+ return currentNode;
+ }
+
+
public List<Integer> test(DecisionTree dt, List<Fact> facts) {
/*
* false | true | unknown
@@ -336,18 +536,39 @@
int i = 0;
for (Fact f : facts) {
+
+ Integer result = dt.test(f);
if (Util.DEBUG_TEST) {
- System.out.println(Util.ntimes("#\n", 5)+i+ " <START> TEST: f="+ f);
- //System.exit(0);
- }
- Integer result = dt.test(f);
-
+ System.out.println(Util.ntimes("#\n", 1)+i+ " <START> TEST: f="+ f + " = target "+ result);
+ } else
+ if (i%1000 ==0) System.out.print(".");
results.set(result, Integer.valueOf(results.get(result) + 1));
i ++;
}
return results;
}
+
+ public List<Fact> getFacts(int fromIndex, int toIndex) {
+ return facts.subList(fromIndex, toIndex); //.iterator();
+ }
+
+ public List<Fact> getFacts() {
+ return facts; //.iterator();
+ }
+
+ public List<Fact> getTrainingFacts() {
+ return training_facts; //.iterator();
+ }
+
+ public List<Fact> getUnClassifiedFacts() {
+ return unclassified_facts; //.iterator();
+ }
+
+
+ public int getNumUnClassifiedFacts() {
+ return unclassified_facts.size(); //.iterator();
+ }
public int getNumCall() {
return FUNC_CALL;
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,5 +1,6 @@
package dt.builder;
+import java.io.Serializable;
import java.util.List;
import dt.DecisionTree;
@@ -7,7 +8,7 @@
import dt.memory.Fact;
import dt.memory.WorkingMemory;
-public interface DecisionTreeBuilder {
+public interface DecisionTreeBuilder extends Serializable{
DecisionTree build(WorkingMemory wm, Class<?> klass, String targetField, List<String> workingAttributes);
@@ -15,7 +16,6 @@
public TreeNode train(DecisionTree dt, List<Fact> facts, List<String> attributeNames);
public List<Integer> test(DecisionTree dt, List<Fact> facts);
//DecisionTree build(WorkingMemory simple, String klass_name, String target_attr,List<String> workingAttributes);
-
int getNum_fact_trained();
void setNum_fact_trained(int num);
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -34,7 +34,8 @@
// if (attr.equalsIgnoreCase(targetDomain.getName()))
// continue;
- System.out.println("Which attribute to try: "+ attr);
+ if (Util.RUN) System.out.println("Which attribute to try: "+ attr);
+
double gain = 0;
if (dt.getDomain(attr).isDiscrete()) {
/* */
@@ -76,7 +77,9 @@
// if (!bestDomain.isDiscrete())
// bestDomain.setIndices(split_indices);
- System.out.println(Util.ntimes("\n",3)+Util.ntimes("!",10)+" NEW BEST "+attributeWithGreatestGain + " the gain "+greatestGain );
+ if (Util.RUN)
+ System.out.println(Util.ntimes("\n",3)+Util.ntimes("!",10)+" NEW BEST "+attributeWithGreatestGain + " the gain "+greatestGain );
+
}
if (attr.equalsIgnoreCase("c2"))
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -141,5 +141,17 @@
public void addIndex(int index) {
// TODO Auto-generated method stub
}
+
+ @Override
+ public boolean equals(Object d_obj) {
+ Domain<?>d = (Domain<?>)d_obj;
+ return (this.getName().equals(d.getName()));
+ }
+
+ public int compare(Object v1, Object v2) {
+ Boolean b1 = (Boolean) v1;
+ Boolean b2 = (Boolean) v2;
+ return b1.equals(b2) ? 0 : 1;
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,9 +1,10 @@
package dt.memory;
+import java.io.Serializable;
import java.util.Comparator;
import java.util.List;
-public interface Domain<T> {
+public interface Domain<T> extends Serializable {
boolean isConstant();
void setConstant();
@@ -38,6 +39,9 @@
List<Integer> getIndices();
void addIndex(int index);
+
+ int compare(Object o1, Object o2);
+
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,16 +1,18 @@
package dt.memory;
+import java.io.Serializable;
+import java.util.HashMap;
import java.util.Hashtable;
import java.util.Set;
-public class Fact {
+public class Fact implements Serializable{
private Hashtable<String, Domain<?>> fields;
- private Hashtable<String, Object> values;
+ private HashMap<String, Object> values;
public Fact() {
- this.values = new Hashtable<String, Object>();
+ this.values = new HashMap<String, Object>();
this.fields = new Hashtable<String, Domain<?>>();
/* while creating the fact i should add the possible keys, the valid domains */
}
@@ -19,7 +21,7 @@
this.fields = new Hashtable<String, Domain<?>>();
for (Domain<?> d: domains)
this.fields.put(d.getName(), d);
- this.values = new Hashtable<String, Object>();
+ this.values = new HashMap<String, Object>();
//this.attributes. of the keys are only these domains
/* while creating the fact i should add the possible keys, the valid domains */
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactLiteralAttributeComparator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactLiteralAttributeComparator.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactLiteralAttributeComparator.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,10 +1,11 @@
package dt.memory;
+import java.io.Serializable;
import java.util.Comparator;
import dt.tools.Util;
-public class FactLiteralAttributeComparator implements Comparator<Fact> {
+public class FactLiteralAttributeComparator implements Comparator<Fact>, Serializable {
private String attr_name;
public FactLiteralAttributeComparator(String _attr_name) {
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,8 +1,9 @@
package dt.memory;
+import java.io.Serializable;
import java.util.Comparator;
-public class FactNumericAttributeComparator implements Comparator<Fact> {
+public class FactNumericAttributeComparator implements Comparator<Fact>, Serializable {
private String attr_name;
public FactNumericAttributeComparator(String _attr_name) {
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSet.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSet.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSet.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,8 +1,9 @@
package dt.memory;
+import java.io.Serializable;
import java.util.Collection;
-public interface FactSet {
+public interface FactSet extends Serializable {
String getClassName();
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -168,5 +168,34 @@
public void addIndex(int index) {
// TODO Auto-generated method stub
}
+
+ @Override
+ public boolean equals(Object d_obj) {
+ Domain<?>d = (Domain<?>)d_obj;
+ if (!this.getName().equals(d.getName())) {
+ return false;
+ }
+ else {
+ if (this.discrete) {
+ return (this.fValues.size() == d.getValues().size());
+ } else if (this.fValues.size() != d.getValues().size()) {
+ return false;
+ } else {
+ List<String> dValues = ((LiteralDomain) d).getValues();
+ for (int i = 0 ; i < this.fValues.size() ; i++)
+ if (!this.fValues.get(i).equals(dValues.get(i)))
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+
+ public int compare(Object v1, Object v2) {
+ String s1 = (String) v1;
+ String s2 = (String) v2;
+ return s1.equals(s2) ? 0 : 1;
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumberComparator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumberComparator.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumberComparator.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,8 +1,9 @@
package dt.memory;
+import java.io.Serializable;
import java.util.Comparator;
-public class NumberComparator implements Comparator<Number> {
+public class NumberComparator implements Comparator<Number>, Serializable {
public NumberComparator() {
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -190,6 +190,13 @@
String out = fName;
return out;
}
+
+ public int compare(Object v1, Object v2) {
+ Number n1 = (Number) v1;
+ Number n2 = (Number) v2;
+ return nComparator.compare(n1, n2);
+ }
+
public Comparator<Fact> factComparator() {
return fComparator;
@@ -207,5 +214,27 @@
public List<Integer> getIndices() {
return indices;
}
+
+ @Override
+ public boolean equals(Object d_obj) {
+ Domain<?>d = (Domain<?>)d_obj;
+ if (!this.getName().equals(d.getName())) {
+ return false;
+ }
+ else {
+ if (this.discrete) {
+ return (this.fValues.size() == d.getValues().size());
+ } else if (this.fValues.size() != d.getValues().size()) {
+ return false;
+ } else {
+ List<Number> dValues = ((NumericDomain) d).getValues();
+ for (int i = 0 ; i < this.fValues.size() ; i++)
+ if (!this.fValues.get(i).equals(dValues.get(i)))
+ return false;
+ }
+ }
+
+ return true;
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,15 +1,17 @@
package dt.memory;
+import java.io.Serializable;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
+import java.util.ArrayList;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import dt.tools.Util;
-public class WorkingMemory {
+public class WorkingMemory implements Serializable{ //TODO do not serialize the wm
private Hashtable<String, FactSet> factsets;
@@ -20,6 +22,33 @@
domainset = new Hashtable<String, Domain<?>>();
}
+ /**
+  * Collects the facts held by the fact sets that match the given class.
+  * <p>
+  * An {@code OOFactSet} matches when its fact class is {@code klass} or a
+  * subtype of it; every matching {@code OOFactSet} contributes. Any other
+  * fact set matches when its class name equals {@code klass.getName()}
+  * ignoring case, and the search stops at the first such match.
+  *
+  * @param klass the class whose facts are requested
+  * @return a new list with the collected facts; empty if nothing matched
+  */
+ public List<Fact> getFacts(Class<?> klass) {
+     List<Fact> facts = new ArrayList<Fact>();
+     Iterator<FactSet> it_fs = this.getFactsets();
+     while (it_fs.hasNext()) {
+         FactSet fs = it_fs.next();
+         if (fs instanceof OOFactSet) {
+             if (klass.isAssignableFrom(((OOFactSet) fs).getFactClass())) {
+                 fs.assignTo(facts); // adding all facts of fs to "facts"
+             }
+         } else if (klass.getName().equalsIgnoreCase(fs.getClassName())) {
+             fs.assignTo(facts); // adding all facts of fs to "facts"
+             break;
+         }
+         // The former reference comparison (==) of the two class names only fed a
+         // local that was never read, so it was dropped together with that local.
+     }
+     return facts;
+ }
+
public OOFactSet getFactSet(Class<?> klass, boolean all_discrete) {
String element_class = klass.getName();
//System.out.println("Get the keys:"+ factsets.keys());
Added: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/DecisionTreeSerializer.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/DecisionTreeSerializer.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/DecisionTreeSerializer.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -0,0 +1,82 @@
+package dt.tools;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+
+import dt.DecisionTree;
+import dt.builder.DecisionTreeBuilder;
+
+/**
+ * Stores and restores a {@link DecisionTree} or {@link DecisionTreeBuilder}
+ * using Java object serialization.
+ */
+public class DecisionTreeSerializer {
+
+    /**
+     * Serializes {@code dt} into the file {@code file_name}, replacing any
+     * previous content of that file.
+     * <p>
+     * I/O failures are reported on the standard error stream and are not
+     * propagated to the caller.
+     *
+     * @param dt        the object to serialize
+     * @param file_name path of the target file
+     */
+    public static void write(Object dt, String file_name) {
+
+        File file = new File(file_name);//"temp.tree"
+
+        // No explicit delete is needed: FileOutputStream truncates an existing file.
+        FileOutputStream f_out = null;
+        try {
+            // Write to disk with FileOutputStream
+            f_out = new FileOutputStream(file);
+
+            // Write object with ObjectOutputStream
+            ObjectOutputStream obj_out = new ObjectOutputStream(f_out);
+
+            // Write object out to disk
+            obj_out.writeObject(dt);// fix the serialization of working memory
+
+            // close() flushes the buffered stream data; without it the file may be
+            // truncated. Done inside the try so that a failing flush is reported.
+            obj_out.close();
+        } catch (FileNotFoundException e) {
+            e.printStackTrace();
+        } catch (IOException e) {
+            e.printStackTrace();
+        } finally {
+            // Releases the file handle when the object stream could not be created or closed.
+            closeQuietly(f_out);
+        }
+    }
+
+    /**
+     * Deserializes the object stored in {@code file_name}.
+     *
+     * @param file_name path of the file to read
+     * @return the stored {@code DecisionTree} or {@code DecisionTreeBuilder},
+     *         or {@code null} if reading failed with an I/O error or the
+     *         stored class could not be found (the error is printed)
+     * @throws Exception if the file is missing or empty, or if it contains an
+     *                   object of any other type
+     */
+    public static Object read(String file_name) throws Exception {
+        File file = new File(file_name);//"temp.tree"
+        if (!file.exists() || (file.length() <= 0)) {
+            // Nothing is created here, so the message no longer claims it is.
+            System.out.println("File does not exist or is empty: " + file_name);
+            throw new Exception("File is not found or empty");
+        }
+        FileInputStream f_in = null;
+        try {
+            // Read from disk using FileInputStream
+            f_in = new FileInputStream(file);
+
+            // Read object using ObjectInputStream
+            ObjectInputStream obj_in = new ObjectInputStream(f_in);
+
+            // Read an object
+            Object obj = obj_in.readObject();
+
+            if (obj instanceof DecisionTree || obj instanceof DecisionTreeBuilder) {
+                System.out.println("The object class found");
+                return obj;
+            } else {
+                throw new Exception("There is something else in the decision tree");
+            }
+        } catch (FileNotFoundException e) {
+            e.printStackTrace();
+        } catch (IOException e) {
+            e.printStackTrace();
+        } catch (ClassNotFoundException e) {
+            e.printStackTrace();
+        } finally {
+            // Closing the underlying file stream releases the handle on every path.
+            closeQuietly(f_in);
+        }
+        return null;
+    }
+
+    /** Closes {@code stream} if it is not {@code null}, ignoring a failure of the close itself. */
+    private static void closeQuietly(java.io.Closeable stream) {
+        if (stream == null) {
+            return;
+        }
+        try {
+            stream.close();
+        } catch (IOException ignored) {
+            // Nothing useful can be done when closing fails during cleanup.
+        }
+    }
+}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -83,16 +83,26 @@
List<Object> obj_read=FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator, all_discrete);
- long dt = System.currentTimeMillis();
+ long st = System.currentTimeMillis();
String target_attr = ObjectReader.getTargetAnnotation(emptyObject.getClass());
List<String> workingAttributes= ObjectReader.getWorkingAttributes(emptyObject.getClass());
C45TreeBuilder bocuk = new C45TreeBuilder(simple);
- DecisionTree bocuksTree = bocuk.build(emptyObject.getClass(), target_attr, workingAttributes);
- dt = System.currentTimeMillis() - dt;
-// System.out.println("Time" + dt + "\n" + bocuksTree);
-//
+ bocuk.init(target_attr, workingAttributes);
+ DecisionTree bocuksTree = bocuk.build(emptyObject.getClass());
+ long train_time = System.currentTimeMillis();
+
+ System.out.println("\nTime to build" + (train_time-st));
+
+ System.out.println(Util.ntimes("\n", 1)+Util.ntimes("$", 5)+" TESTING "+Util.ntimes("\n", 1));
+ List<Integer> evaluation = bocuk.test(bocuksTree, bocuk.getFacts());//.subList(339, 340));
+ long test_time = System.currentTimeMillis();
+ System.out.println("Time to test" + (test_time-train_time) + "\n" );
+ System.out.println("TESTING results: Mistakes "+ evaluation.get(0));
+ System.out.println("TESTING results: Corrects "+ evaluation.get(1));
+ System.out.println("TESTING results: Unknown "+ evaluation.get(2));
+
// RulePrinter my_printer = new RulePrinter(bocuk.getNum_fact_trained());
// if (max_rules >0)
// my_printer.setMax_num_rules(max_rules);
@@ -111,4 +121,27 @@
}
+
+
+ /**
+  * Loads the objects of {@code emptyObject}'s class from {@code datafile}
+  * through {@code FactSetFactory.fromFileAsObject} (with the all-discrete
+  * flag off) and prints how long the loading took.
+  *
+  * @param simple      the working memory handed to the factory
+  * @param emptyObject an instance whose class identifies the objects to read
+  * @param datafile    the data file to read
+  * @param separator   the separator passed to the factory
+  * @return the objects that were read, or {@code null} if any exception
+  *         occurred (its stack trace is printed)
+  */
+ public static List<Object> test_process(WorkingMemory simple, Object emptyObject, String datafile, String separator) {
+     try {
+         final long started = System.currentTimeMillis();
+         final List<Object> loaded =
+                 FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator, false);
+         final long elapsed = System.currentTimeMillis() - started;
+
+         System.out.println("\nTime to process_objects " + elapsed);
+         return loaded;
+     } catch (Exception e) {
+         e.printStackTrace();
+         return null;
+     }
+ }
+
}
Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java 2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java 2008-04-20 14:11:35 UTC (rev 19651)
@@ -6,10 +6,11 @@
public class Util {
+ public static boolean RUN = true;
+ public static boolean DEBUG = true;
+ public static boolean DEBUG_RETRAIN = true;
+ public static boolean DEBUG_TEST = false;
- public static boolean DEBUG = false;
- public static boolean DEBUG_TEST = true;
-
public static String ntimes(String s,int n){
StringBuffer buf = new StringBuffer();
for (int i = 0; i < n; i++) {
More information about the jboss-svn-commits
mailing list