[jboss-svn-commits] JBL Code SVN: r19651 - in labs/jbossrules/contrib/machinelearning/decisiontree/src/dt: builder and 2 other directories.

jboss-svn-commits at lists.jboss.org jboss-svn-commits at lists.jboss.org
Sun Apr 20 10:11:36 EDT 2008


Author: gizil
Date: 2008-04-20 10:11:35 -0400 (Sun, 20 Apr 2008)
New Revision: 19651

Added:
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/DecisionTreeSerializer.java
Modified:
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactLiteralAttributeComparator.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSet.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumberComparator.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
   labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java
Log:
serializing decision tree and tree builder + the dumbest way of retraining a decision tree (= iterating over an existing tree)
  


Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/DecisionTree.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,15 +1,22 @@
 package dt;
 
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Hashtable;
 import java.util.List;
 
 import dt.memory.Domain;
 import dt.memory.Fact;
-import dt.tools.Util;
 
-public class DecisionTree {
+public class DecisionTree implements Serializable{
 
+
+
+	/**
+	 * 
+	 */
+	private static final long serialVersionUID = 1L;
+
 	public long FACTS_READ = 0;
 
 	/* set of the attributes, their types */

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/LeafNode.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -12,13 +12,34 @@
 	private Object targetValue;
 	private double rank;
 	private int num_facts_classified;
+	
+	private Fact pseudo_f;
 
 	public LeafNode(Domain<?> targetDomain, Object value){
 		super(targetDomain);
 		this.targetValue = value;
 		num_facts_classified = 0;
+		
+		this.pseudo_f = new Fact();
+		this.setPseudoFact();
 	}
+	public void setTargetValue(Object value) {
+		this.targetValue = value;
+		this.pseudo_f = new Fact();
+		this.setPseudoFact();
+	}
 	
+	public void setPseudoFact() {
+		try {
+			pseudo_f.add(this.getDomain(), this.getValue());
+		} catch (Exception e) {
+			System.out.println(Util.ntimes("\n", 10)+"Unknown situation at leafnode: " + this.getValue() + " @ "+ this.getDomain());
+			e.printStackTrace();
+			// Unknown
+			System.exit(0);
+
+		}
+	}
 	public void addNode(Object attributeValue, TreeNode node) {
 		throw new RuntimeException("cannot add Node to a leaf node");
 	}
@@ -42,26 +63,14 @@
 	public Integer evaluate(Fact f) {
 		
 		Domain<?> target_domain = this.getDomain();
-		Fact pseudo_f = new Fact();
-		try {
-			pseudo_f.add(target_domain, this.getValue());
-			Comparator<Fact> targetComp = target_domain.factComparator();
-			if (targetComp.compare(f, pseudo_f) == 0 ) {
-				return Integer.valueOf(1);
-			} else {
-				return Integer.valueOf(0);
-			}
-		} catch (Exception e) {
-			
-			System.out.println(Util.ntimes("\n", 10)+"Unknown situation at leafnode: " + this.getValue() + " @ "+ target_domain);
-			e.printStackTrace();
-			// Unknown
-			System.exit(0);
-			return Integer.valueOf(2);
-		}
 		
+		Comparator<Fact> targetComp = target_domain.factComparator();
+		if (targetComp.compare(f, this.pseudo_f) == 0 ) {
+			return Integer.valueOf(1); 	//correct
+		} else {
+			return Integer.valueOf(0);	// mistake
+		}		
 		
-		
 	}
 	
 	public String toString(){

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/TreeNode.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,4 +1,5 @@
 package dt;
+import java.io.Serializable;
 import java.util.Collection;
 import java.util.Hashtable;
 
@@ -7,7 +8,7 @@
 import dt.tools.Util;
 
 
-public class TreeNode {
+public class TreeNode implements Serializable{
 	
 	private Domain<?> domain;
 	private Hashtable<Object, TreeNode> children;

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/C45TreeBuilder.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,8 +1,8 @@
 package dt.builder;
 
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.Hashtable;
 import java.util.Iterator;
 import java.util.List;
@@ -10,14 +10,12 @@
 import dt.DecisionTree;
 import dt.LeafNode;
 import dt.TreeNode;
-
+import dt.memory.Domain;
+import dt.memory.Fact;
 import dt.memory.FactDistribution;
-import dt.memory.FactTargetDistribution;
-import dt.memory.WorkingMemory;
-import dt.memory.Fact;
 import dt.memory.FactSet;
 import dt.memory.OOFactSet;
-import dt.memory.Domain;
+import dt.memory.WorkingMemory;
 import dt.tools.FactProcessor;
 import dt.tools.Util;
 
@@ -42,10 +40,14 @@
 	MyThread helper;
 	private int FUNC_CALL = 0;
 	protected int num_fact_trained = 0;
+	private ArrayList<Fact> facts;
+	private ArrayList<Fact> training_facts;
 	private ArrayList<Fact> unclassified_facts;
-	private ArrayList<Fact> training_facts;
+	
 	private WorkingMemory global_wm;
 	private List<Domain<?>> domains;
+	private String target;
+	private List<String> attributes;
 	
 	/*
 	 * treebuilder.execute(workingmemory, classtoexecute, attributestoprocess)
@@ -56,23 +58,43 @@
 	 * internalprocess(attributestoprocess)
 	 */
 	public C45TreeBuilder(WorkingMemory wm) {
+	
+		global_wm = wm;
 		
+		facts = new ArrayList<Fact>();
+		training_facts = new ArrayList<Fact>();
 		unclassified_facts = new ArrayList<Fact>();
-		training_facts = new ArrayList<Fact>();
-		global_wm = wm;
+	
+		target = null;
+		attributes = new ArrayList<String>();
 		domains = new ArrayList<Domain<?>>();
-	
 	}
 	
 	public C45TreeBuilder() {
 		
+		facts = new ArrayList<Fact>();
+		training_facts = new ArrayList<Fact>();
 		unclassified_facts = new ArrayList<Fact>();
-		training_facts = new ArrayList<Fact>();
+
+		target = null;
+		attributes = new ArrayList<String>();
 		domains = new ArrayList<Domain<?>>();
-	
+		
 	}
+	/* set the builder's 
+	 *  domains
+	 */
+	public void setDomains(Class<?> klass) {
+		FactSet klass_fs = null;
+		
+		for (Domain<?> d : klass_fs.getDomains())
+			domains.add(d);
+	}
 	
-	
+	/* set the builder's 
+	 * 	facts
+	 *  domains
+	 */
 	private void setKlass(Class<?> klass) {
 		Iterator<FactSet> it_fs = global_wm.getFactsets();
 		FactSet klass_fs = null;
@@ -81,10 +103,10 @@
 			if (fs instanceof OOFactSet) {
 				if (klass.isAssignableFrom(((OOFactSet) fs).getFactClass())) {
 					// **OPT facts.add(fs);
-					fs.assignTo(training_facts); // adding all facts of fs to "facts
+					fs.assignTo(facts); // adding all facts of fs to "facts
 				}
 			} else if (klass.getName().equalsIgnoreCase(fs.getClassName())) {
-				fs.assignTo(training_facts); // adding all facts of fs to "facts"
+				fs.assignTo(facts); // adding all facts of fs to "facts"
 
 				klass_fs = fs;
 				break;
@@ -98,8 +120,47 @@
 			domains.add(d);
 	}
 	
-	private void init(DecisionTree dt, String targetField, List<String> workingAttributes) {
+	/* initialize the builder's 
+	 * 	targetField
+	 * 	the attribute list (workingAttributes != null ? workingAttributes : domains )
+	 */
+	public void init(String targetField, List<String> workingAttributes) {
+		this.setTarget(targetField);
+		if (workingAttributes != null)
+			for (String attr : workingAttributes) {
+				this.addAttribute(attr);
+			}
+		else {
+			for (Domain<?> d : domains) {
+				this.addAttribute(d.getName());	
+			}
+		}
+		
+	}
+	
+	public void setTarget(String targetField) {
+		this.target = targetField;
+		//attrsToClassify.remove(target);
+	}
+	
+	public void addDomain(Domain<?> d) {
+		//if (!attribute.equals(this.target))
+			//attributes.add(d.getName());
+			domains.add(d);
+	}
+	public void addAttribute(String attribute) {
+		//if (!attribute.equals(this.target))
+			attributes.add(attribute);
+	}
+	private void init_dt(DecisionTree dt, String targetField) {
 		dt.setTarget(targetField);
+		for (Domain<?> d : domains) {
+			dt.addDomain(d);	
+		}
+		
+	}
+	private void init_dt(DecisionTree dt, String targetField, List<String> workingAttributes) {
+		dt.setTarget(targetField);
 		if (workingAttributes != null)
 			for (String attr : workingAttributes) {
 				dt.addDomain(global_wm.getDomain(attr));
@@ -111,23 +172,18 @@
 		}
 		
 	}
-	
-	/* building with a training and test */
-	public DecisionTree build(Class<?> klass, String targetField, List<String> workingAttributes) {
+	/* building with the training set (all relative facts from wm) from scratch*/
+	public DecisionTree build(Class<?> klass) {
 		/* gets the facts the decision tree is eligible */
 		setKlass(klass);
 		
 		DecisionTree dt = new DecisionTree(klass.getName());
-		init(dt, targetField, workingAttributes);
+		init_dt(dt, this.target, this.attributes);
 		
-		
-		DecisionTree best_dt = new DecisionTree(klass.getName());
-		init(dt, targetField, workingAttributes);
-		
 		ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
 		Collections.sort(attrs);
 		
-		
+		training_facts.addAll(facts);
 		dt.FACTS_READ += training_facts.size();
 		/* you must set this when the training called the first time */
 		setNum_fact_trained(training_facts.size());
@@ -135,20 +191,29 @@
 		//while ()
 		TreeNode root = train(dt, training_facts, attrs);
 		dt.setRoot(root);
+	
 		
+		return dt;
+	}
+		
+	
+	/* building with a training and test */
+	public DecisionTree build_test(Class<?> klass, String targetField, List<String> workingAttributes) {
+		if (this.target == null) {
+			System.out.println("Target is not set");
+			System.exit(0);
+		}
+		/* gets the facts the decision tree is eligible */
+		DecisionTree tree = build(klass);
+		
 		System.out.println(Util.ntimes("\n", 2)+Util.ntimes("$", 5)+" TESTING "+Util.ntimes("\n", 2));
-		List<Integer> evaluation = test(dt, training_facts.subList(339, 340));
+		List<Integer> evaluation = test(tree, training_facts);//.subList(339, 340));
 
 		System.out.println("TESTING results: Mistakes "+ evaluation.get(0));
 		System.out.println("TESTING results: Corrects "+ evaluation.get(1));
-		System.out.println("TESTING results: Unknown "+ evaluation.get(2));
-		if (evaluation.get(1) == training_facts.size()) {
-			best_dt.setRoot(root);
-		}
-		
-		return dt;
+		System.out.println("TESTING results: Unknown "+ evaluation.get(2) +" OF "+  training_facts.size() + " facts");
+		return tree;
 	}
-	
 
 	public DecisionTree build(WorkingMemory wm, Class<?> klass,
 			String targetField, List<String> workingAttributes) {
@@ -199,50 +264,32 @@
 		return dt;
 	}
 
-/*	public DecisionTree build(WorkingMemory wm, String klass,
-			String targetField, List<String> workingAttributes) {
-		unclassified_facts = new ArrayList<Fact>();
-		DecisionTree dt = new DecisionTree(klass);
-		// **OPT List<FactSet> facts = new ArrayList<FactSet>();
-		ArrayList<Fact> facts = new ArrayList<Fact>();
-		FactSet klass_fs = null;
-		Iterator<FactSet> it_fs = wm.getFactsets();
-		while (it_fs.hasNext()) {
-			FactSet fs = it_fs.next();
-			if (klass == fs.getClassName()) {
-				// **OPT facts.add(fs);
-				fs.assignTo(facts); // adding all facts of fs to "facts"
-
-				klass_fs = fs;
-				break;
-			}
-		}
-		dt.FACTS_READ += facts.size();
-		setNum_fact_processed(facts.size());
-
-		if (workingAttributes != null)
-			for (String attr : workingAttributes) {
-				//System.out.println("Bok degil " + attr);
-				dt.addDomain(klass_fs.getDomain(attr));
-			}
-		else
-			for (Domain<?> d : klass_fs.getDomains())
-				dt.addDomain(d);
-
-		dt.setTarget(targetField);
-
+	/* building with the training set (some part of the facts) */
+	public DecisionTree build(Class<?> klass, List<Fact> first_facts) {
+		/* gets the facts which the decision tree is eligible */
+		//setKlass(klass);
+		
+		DecisionTree dt = new DecisionTree(klass.getName());
+		init_dt(dt, this.target); // initialize the decision tree with the target and all domains
+		
 		ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
 		Collections.sort(attrs);
+		
+		training_facts.addAll(first_facts);
+		dt.FACTS_READ += first_facts.size();
+		/* you must set this when the training called the first time */
+		setNum_fact_trained(training_facts.size());
 
-		TreeNode root = c45(dt, facts, attrs);
+		//while ()
+		TreeNode root = train(dt, training_facts, attrs);
 		dt.setRoot(root);
-
+		
+		
 		return dt;
-	}*/
+	}
+	
+	public TreeNode train(DecisionTree dt, List<Fact> facts, List<String> attributeNames) {
 
-	public TreeNode train(DecisionTree dt, List<Fact> facts,
-			List<String> attributeNames) {
-
 		FUNC_CALL++;
 		if (facts.size() == 0) {
 			throw new RuntimeException("Nothing to classify, factlist is empty");
@@ -285,7 +332,7 @@
 //		String chosenAttribute = Entropy.chooseContAttribute(dt, facts, stats, attributeNames);
 //		List<?> categorization = dt.getPossibleValues(chosenAttribute);
 		Domain<?> choosenDomain = Entropy.chooseContAttribute(dt, facts, stats, attributeNames);
-		System.out.println(Util.ntimes("*", 20) + " 1st best attr: "+ choosenDomain.getName());
+		if (Util.RUN)	System.out.println(Util.ntimes("*", 20) + " 1st best attr: "+ choosenDomain.getName());
 
 		TreeNode currentNode = new TreeNode(choosenDomain);
 			
@@ -321,8 +368,161 @@
 		return currentNode;
 	}
 
+	/* building with the training set (some part of the facts) */
+	public DecisionTree re_build(DecisionTree dt, List<Fact> new_facts) {
+		
+		ArrayList<String> attrs = new ArrayList<String>(dt.getAttributes());
+		Collections.sort(attrs);
+		
+		training_facts.addAll(new_facts);
+		dt.FACTS_READ += new_facts.size();
+		/* you must set this when the training called the first time */
+		setNum_fact_trained(training_facts.size());
+		System.out.println(Util.ntimes("\n", 10)+"How facts are u training? "+ training_facts.size());
+		//while ()
+		TreeNode root = re_train(dt, dt.getRoot(), training_facts, attrs);
+		dt.setRoot(root);
+		
+		return dt;
+	}
 	
+	public TreeNode re_train(DecisionTree dt, TreeNode currentNode, List<Fact> facts, List<String> attributeNames) {
+
+		FUNC_CALL++;
+		if (facts.size() == 0) {
+			throw new RuntimeException("Nothing to classify, factlist is empty");
+		}
+		/* let's get the statistics of the results */
+		// List<?> targetValues = dt.getPossibleValues(dt.getTarget());
+		//Hashtable<Object, Integer> stats_ = dt.getStatistics(facts, dt.getTarget());// targetValues
+		
+		//FactTargetDistribution stats = dt.getDistribution(facts);
+		
+		FactDistribution stats = new FactDistribution(dt.getDomain(dt.getTarget()));
+		stats.calculateDistribution(facts);
+		stats.evaluateMajority();
+
+		/* if all elements are classified to the same value */
+		if (stats.getNum_supported_target_classes() == 1) {
+
+			LeafNode classifiedNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
+			classifiedNode.setRank((double) facts.size()/(double) getNum_fact_trained());
+			classifiedNode.setNumSupporter(facts.size());
+			
+			return classifiedNode;
+		}
+
+		/* if there is no attribute left in order to continue */
+		if (attributeNames.size() == 0) {
+			/* an heuristic of the leaf classification */
+			Object winner = stats.getThe_winner_target_class();
+			LeafNode noAttributeLeftNode = new LeafNode(dt.getDomain(dt.getTarget()), winner);
+			noAttributeLeftNode.setRank((double) stats.getVoteFor(winner)/ (double) num_fact_trained);
+			noAttributeLeftNode.setNumSupporter(stats.getVoteFor(winner));
+			
+			/* we need to know how many guys cannot be classified and who these guys are */
+			FactProcessor.splitUnclassifiedFacts(unclassified_facts, stats);
+			
+			return noAttributeLeftNode;
+		}
+
+		/* choosing the attribute for the branching starts */
+//		String chosenAttribute = Entropy.chooseContAttribute(dt, facts, stats, attributeNames);
+//		List<?> categorization = dt.getPossibleValues(chosenAttribute);
+		Domain<?> choosenDomain = Entropy.chooseContAttribute(dt, facts, stats, attributeNames);
+		if (Util.RUN)	System.out.println(Util.ntimes("*", 20) + " 1st best attr: "+ choosenDomain.getName());
+		else if (FUNC_CALL % 100 ==0){
+			System.out.print(".");
+		}
+		
+		Hashtable<Object, List<Fact>> filtered_facts = FactProcessor.splitFacts(facts, choosenDomain);
+		for (Object value : filtered_facts.keySet()) {
+			if (filtered_facts.get(value).isEmpty()){
+				@SuppressWarnings("unused")
+				boolean bok = true;
+			}
+		}
+		dt.FACTS_READ += facts.size();
+		
+		if (currentNode.getDomain() == choosenDomain) {
+			
+			
+			for (Object value : filtered_facts.keySet()) {
+				
+				TreeNode childNode = currentNode.getChild(value);
+				/* split the last two class at the same time */
+
+				ArrayList<String> attributeNames_copy = new ArrayList<String>(
+						attributeNames);
+				attributeNames_copy.remove(choosenDomain.getName());
+
+				if (filtered_facts.get(value).isEmpty()) {
+					/* majority !!!! */
+					//Comparator<Fact> targetComp = dt.getDomain(dt.getTarget()).factComparator(); 
+					
+					if (childNode == null || !(childNode instanceof LeafNode)) {
+						LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
+						majorityNode.setRank(-1.0); // classifying nothing
+						majorityNode.setNumSupporter(filtered_facts.get(value).size());
+						
+						childNode = majorityNode; // How to set this guy
+						if (childNode == null)
+							currentNode.addNode(value, childNode);
+					}
+					
+					else {
+						/* have to remove the leafnode from the children list with key value*/
+						((LeafNode)childNode).setRank(-1.0); // classifying nothing
+						((LeafNode)childNode).setNumSupporter(filtered_facts.get(value).size());
+						
+						if (dt.getDomain(dt.getTarget()).compare(((LeafNode)childNode).getValue(), value)!=0) {
+							((LeafNode)childNode).setTargetValue(value);
+							//currentNode.
+						}
+					} 
+					
+						
+					
+				} else {
+					if (childNode == null) {
+						TreeNode newNode = train(dt, filtered_facts.get(value), attributeNames_copy);
+						currentNode.addNode(value, newNode);
+					}
+					TreeNode newNode = re_train(dt, childNode, filtered_facts.get(value), attributeNames_copy);
+					//currentNode.addNode(value, newNode);
+				}
+			}
+			
+		} else {
+			currentNode = new TreeNode(choosenDomain);
+			
+			for (Object value : filtered_facts.keySet()) {
+			
+				/* split the last two class at the same time */
+
+				ArrayList<String> attributeNames_copy = new ArrayList<String>(
+						attributeNames);
+				attributeNames_copy.remove(choosenDomain.getName());
+
+				if (filtered_facts.get(value).isEmpty()) {
+					/* majority !!!! */
+					LeafNode majorityNode = new LeafNode(dt.getDomain(dt.getTarget()), stats.getThe_winner_target_class());
+					majorityNode.setRank(-1.0); // classifying nothing
+					majorityNode.setNumSupporter(filtered_facts.get(value).size());
+					currentNode.addNode(value, majorityNode);
+				} else {
+					
+					TreeNode newNode = train(dt, filtered_facts.get(value), attributeNames_copy);
+					currentNode.addNode(value, newNode);
+				}
+			}
+		}
+
+		return currentNode;
+	}
+
 	
+	
 	public List<Integer> test(DecisionTree dt, List<Fact> facts) {
 		/*
 		 * false | true | unknown
@@ -336,18 +536,39 @@
 		
 		int i = 0;
 		for (Fact f : facts) {
+
+			Integer result = dt.test(f);
 			if (Util.DEBUG_TEST) {
-				System.out.println(Util.ntimes("#\n", 5)+i+ " <START> TEST: f="+ f);
-				//System.exit(0);
-			}
-			Integer result = dt.test(f);
-			
+				System.out.println(Util.ntimes("#\n", 1)+i+ " <START> TEST: f="+ f + " = target "+ result);
+			} else
+				if (i%1000 ==0)	System.out.print(".");
 			results.set(result, Integer.valueOf(results.get(result) + 1));
 			i ++;
 		}
 		return results;
 		
 	}
+	
+	public List<Fact> getFacts(int fromIndex, int toIndex) {
+		return facts.subList(fromIndex, toIndex); //.iterator();
+	}
+		
+	public List<Fact> getFacts() {
+		return facts; //.iterator();
+	}
+	
+	public List<Fact> getTrainingFacts() {
+		return training_facts; //.iterator();
+	}
+	
+	public List<Fact> getUnClassifiedFacts() {
+		return unclassified_facts; //.iterator();
+	}
+	
+	
+	public int getNumUnClassifiedFacts() {
+		return unclassified_facts.size(); //.iterator();
+	}
 
 	public int getNumCall() {
 		return FUNC_CALL;

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/DecisionTreeBuilder.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,5 +1,6 @@
 package dt.builder;
 
+import java.io.Serializable;
 import java.util.List;
 
 import dt.DecisionTree;
@@ -7,7 +8,7 @@
 import dt.memory.Fact;
 import dt.memory.WorkingMemory;
 
-public interface DecisionTreeBuilder {
+public interface DecisionTreeBuilder extends Serializable{
 	
 	
 	DecisionTree build(WorkingMemory wm, Class<?> klass, String targetField, List<String> workingAttributes);
@@ -15,7 +16,6 @@
 	public TreeNode train(DecisionTree dt, List<Fact> facts, List<String> attributeNames);
 	public List<Integer> test(DecisionTree dt, List<Fact> facts);
 	//DecisionTree build(WorkingMemory simple, String klass_name, String target_attr,List<String> workingAttributes);
-
 	int getNum_fact_trained();
 	void setNum_fact_trained(int num);
 }

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/builder/Entropy.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -34,7 +34,8 @@
 			
 //			if (attr.equalsIgnoreCase(targetDomain.getName()))
 //				continue;
-			System.out.println("Which attribute to try: "+ attr);
+			if (Util.RUN)	System.out.println("Which attribute to try: "+ attr);
+			
 			double gain = 0;
 			if (dt.getDomain(attr).isDiscrete()) {
 				/* */
@@ -76,7 +77,9 @@
 //				if (!bestDomain.isDiscrete())
 //					bestDomain.setIndices(split_indices);
 				
-				System.out.println(Util.ntimes("\n",3)+Util.ntimes("!",10)+" NEW BEST "+attributeWithGreatestGain + " the gain "+greatestGain );
+				if (Util.RUN)	
+					System.out.println(Util.ntimes("\n",3)+Util.ntimes("!",10)+" NEW BEST "+attributeWithGreatestGain + " the gain "+greatestGain );
+
 			}
 			
 			if (attr.equalsIgnoreCase("c2"))

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/BooleanDomain.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -141,5 +141,17 @@
 	public void addIndex(int index) {
 		// TODO Auto-generated method stub	
 	}
+	
+	@Override
+	public boolean equals(Object d_obj) {
+		Domain<?>d = (Domain<?>)d_obj;
+		return (this.getName().equals(d.getName()));
+	}
+	
+	public int compare(Object v1, Object v2) {
+		Boolean b1 = (Boolean) v1;
+		Boolean b2 = (Boolean) v2;
+		return b1.equals(b2) ? 0 : 1;
+	}
 
 }

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Domain.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,9 +1,10 @@
 package dt.memory;
 
+import java.io.Serializable;
 import java.util.Comparator;
 import java.util.List;
 
-public interface Domain<T> {
+public interface Domain<T> extends Serializable {
 	
 	boolean isConstant();
 	void setConstant();
@@ -38,6 +39,9 @@
 	List<Integer> getIndices();
 	void addIndex(int index);
 	
+	
+	int compare(Object o1, Object o2);
+	
 }
 
 

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/Fact.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,16 +1,18 @@
 package dt.memory;
 
+import java.io.Serializable;
+import java.util.HashMap;
 import java.util.Hashtable;
 import java.util.Set;
 
 
-public class Fact {
+public class Fact implements Serializable{
 
 	private Hashtable<String, Domain<?>> fields;
-	private Hashtable<String, Object> values;
+	private HashMap<String, Object> values;
 
 	public Fact() {
-		this.values = new Hashtable<String, Object>();
+		this.values = new HashMap<String, Object>();
 		this.fields = new Hashtable<String, Domain<?>>();
 		/* while creating the fact i should add the possible keys, the valid domains */
 	}
@@ -19,7 +21,7 @@
 		this.fields = new Hashtable<String, Domain<?>>();
 		for (Domain<?> d: domains)
 			this.fields.put(d.getName(), d);
-		this.values = new Hashtable<String, Object>();
+		this.values = new HashMap<String, Object>();
 		//this.attributes. of the keys are only these domains
 		/* while creating the fact i should add the possible keys, the valid domains */
 	}

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactLiteralAttributeComparator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactLiteralAttributeComparator.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactLiteralAttributeComparator.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,10 +1,11 @@
 package dt.memory;
 
+import java.io.Serializable;
 import java.util.Comparator;
 
 import dt.tools.Util;
 
-public class FactLiteralAttributeComparator  implements Comparator<Fact> {
+public class FactLiteralAttributeComparator  implements Comparator<Fact>, Serializable {
 	private String attr_name;
 
 	public FactLiteralAttributeComparator(String _attr_name) {

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactNumericAttributeComparator.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,8 +1,9 @@
 package dt.memory;
 
+import java.io.Serializable;
 import java.util.Comparator;
 
-public class FactNumericAttributeComparator implements Comparator<Fact> {
+public class FactNumericAttributeComparator implements Comparator<Fact>, Serializable {
 	private String attr_name;
 
 	public FactNumericAttributeComparator(String _attr_name) {

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSet.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSet.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/FactSet.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,8 +1,9 @@
 package dt.memory;
 
+import java.io.Serializable;
 import java.util.Collection;
 
-public interface FactSet {
+public interface FactSet extends Serializable {
 
 	String getClassName();
 	

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/LiteralDomain.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -168,5 +168,34 @@
 	public void addIndex(int index) {
 		// TODO Auto-generated method stub	
 	}
+	
+	@Override
+	public boolean equals(Object d_obj) {
+		Domain<?>d = (Domain<?>)d_obj;
+		if (!this.getName().equals(d.getName())) {
+			return false;
+		} 
+		else { 
+			if (this.discrete) {
+				return (this.fValues.size() == d.getValues().size());
+			} else if (this.fValues.size() != d.getValues().size()) {
+					return false;
+				} else {
+					List<String> dValues = ((LiteralDomain) d).getValues();
+					for (int i = 0 ; i < this.fValues.size() ; i++)
+						if (!this.fValues.get(i).equals(dValues.get(i)))
+							return false;
+				}
+		} 
+		
+		return true;
+	}
+	
+	
+	public int compare(Object v1, Object v2) {
+		String s1 = (String) v1;
+		String s2 = (String) v2;
+		return s1.equals(s2) ? 0 : 1;
+	}
 
 }

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumberComparator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumberComparator.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumberComparator.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,8 +1,9 @@
 package dt.memory;
 
+import java.io.Serializable;
 import java.util.Comparator;
 
-public class NumberComparator implements Comparator<Number> {
+public class NumberComparator implements Comparator<Number>, Serializable {
 	public NumberComparator() {
 	}
 

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/NumericDomain.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -190,6 +190,13 @@
 		String out = fName;
 		return out;
 	}
+	
+	public int compare(Object v1, Object v2) {
+		// Both operands are cast to Number; the ordering itself is delegated to nComparator.
+		final Number left = (Number) v1;
+		return nComparator.compare(left, (Number) v2);
+	}
+	
 
 	public Comparator<Fact> factComparator() {
 		return fComparator;
@@ -207,5 +214,27 @@
 	public List<Integer> getIndices() {
 		return indices;
 	}
+	
+	/** Equal when names and value counts match; non-discrete domains also compare values in order. */
+	@Override
+	public boolean equals(Object d_obj) {
+		// instanceof also rejects null; the blind cast used to throw for null or foreign arguments
+		if (!(d_obj instanceof Domain)) {
+			return false;
+		}
+		Domain<?> d = (Domain<?>) d_obj;
+		if (!this.getName().equals(d.getName()) || this.fValues.size() != d.getValues().size()) {
+			return false;
+		}
+		if (this.discrete || !(d instanceof NumericDomain)) {
+			// discrete: name and count suffice; otherwise element comparison needs the same domain type
+			return this.discrete;
+		}
+		List<Number> dValues = ((NumericDomain) d).getValues();
+		for (int i = 0 ; i < this.fValues.size() ; i++)
+			if (!this.fValues.get(i).equals(dValues.get(i)))
+				return false;
+		return true;
+	}
 
 }

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/memory/WorkingMemory.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -1,15 +1,17 @@
 package dt.memory;
 
+import java.io.Serializable;
 import java.lang.annotation.Annotation;
 import java.lang.reflect.Field;
 import java.lang.reflect.Method;
+import java.util.ArrayList;
 import java.util.Hashtable;
 import java.util.Iterator;
 import java.util.List;
 
 import dt.tools.Util;
 
-public class WorkingMemory {
+public class WorkingMemory implements Serializable{ //TODO do not serialize the wm
 	
 	private Hashtable<String, FactSet> factsets;
 
@@ -20,6 +22,33 @@
 		domainset = new Hashtable<String, Domain<?>>();
 	}
 	
+	/**
+	 * Collects the facts stored for the given class into a new list.
+	 * <p>
+	 * Every OOFactSet whose fact class is klass or a subtype of it contributes its facts.
+	 * For any other fact set the class name is compared ignoring case, and the scan
+	 * stops at the first such match.
+	 *
+	 * @param klass the class whose facts are requested
+	 * @return a new list holding whatever the matching fact sets assigned to it
+	 */
+	public List<Fact> getFacts(Class<?> klass) {
+		Iterator<FactSet> it_fs = this.getFactsets();
+		List<Fact> facts = new ArrayList<Fact>();
+		while (it_fs.hasNext()) {
+			FactSet fs = it_fs.next();
+			if (fs instanceof OOFactSet) {
+				if (klass.isAssignableFrom(((OOFactSet) fs).getFactClass())) {
+					fs.assignTo(facts); // adding all facts of fs to "facts"
+				}
+			} else if (klass.getName().equalsIgnoreCase(fs.getClassName())) {
+				fs.assignTo(facts); // adding all facts of fs to "facts"
+				break;
+			}
+		}
+		return facts;
+	}
+	
 	public OOFactSet getFactSet(Class<?> klass, boolean all_discrete) {
 		String element_class = klass.getName();
 		//System.out.println("Get the keys:"+ factsets.keys());

Added: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/DecisionTreeSerializer.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/DecisionTreeSerializer.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/DecisionTreeSerializer.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -0,0 +1,82 @@
+package dt.tools;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+
+import dt.DecisionTree;
+import dt.builder.DecisionTreeBuilder;
+
+public class DecisionTreeSerializer {
+	
+	/**
+	 * Serializes dt to the file named file_name, replacing any previous content.
+	 * I/O failures are printed to stderr and not propagated.
+	 */
+	public static void write(Object dt, String file_name) {
+		File file = new File(file_name);//"temp.tree"
+		if (file.exists() && (file.length() > 0))
+			file.delete();	// should i delete the tree if it already exists??
+		
+		FileOutputStream f_out = null;
+		try {
+			f_out = new FileOutputStream(file);
+			ObjectOutputStream obj_out = new ObjectOutputStream(f_out);
+			obj_out.writeObject(dt);// fix the serialization of working memory
+			obj_out.flush(); // push buffered object data to the file before it is closed
+		} catch (IOException e) { // also covers FileNotFoundException
+			e.printStackTrace();
+		} finally {
+			close(f_out); // the stream used to be left open
+		}
+	}
+	
+	/**
+	 * Reads back an object stored by write(). Uses Java native deserialization,
+	 * so only files from a trusted source should be read.
+	 *
+	 * @param file_name path of the file to read
+	 * @return the stored DecisionTree or DecisionTreeBuilder, or null when an
+	 *         IOException or ClassNotFoundException occurred (its stack trace is printed)
+	 * @throws Exception if the file is missing or empty, or holds an object of another type
+	 */
+	public static Object read(String file_name) throws Exception {	
+		File file = new File(file_name);//"temp.tree"
+		if (!file.exists() || (file.length() <= 0)) {
+			// Message corrected: nothing is created here, the read is refused.
+			System.out.println("File does not exist or is empty: " + file_name);
+			throw new Exception("File is not found or empty");
+		}
+		FileInputStream f_in = null;
+		try {
+			f_in = new FileInputStream(file);
+			Object obj = new ObjectInputStream(f_in).readObject();
+			if (obj instanceof DecisionTree || obj instanceof DecisionTreeBuilder) {
+				System.out.println("The object class found");
+				return obj;
+			}
+			throw new Exception("There is something else in the decision tree");
+		} catch (IOException e) { // also covers FileNotFoundException
+			e.printStackTrace();
+		} catch (ClassNotFoundException e) {
+			e.printStackTrace();
+		} finally {
+			close(f_in); // the stream used to be left open
+		}
+		return null;
+	}
+	
+	/** Closes a stream if it was opened; a failure to close is printed, not propagated. */
+	private static void close(java.io.Closeable stream) {
+		try {
+			if (stream != null)
+				stream.close();
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+	}
+}

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/FileProcessor.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -83,16 +83,26 @@
 			List<Object> obj_read=FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator, all_discrete);
 			
 
-			long dt = System.currentTimeMillis();
+			long st = System.currentTimeMillis();
 			String target_attr = ObjectReader.getTargetAnnotation(emptyObject.getClass());
 			
 			List<String> workingAttributes= ObjectReader.getWorkingAttributes(emptyObject.getClass());
 			
 			C45TreeBuilder bocuk = new C45TreeBuilder(simple);
-			DecisionTree bocuksTree = bocuk.build(emptyObject.getClass(), target_attr, workingAttributes);
-			dt = System.currentTimeMillis() - dt;
-//			System.out.println("Time" + dt + "\n" + bocuksTree);
-//
+			bocuk.init(target_attr, workingAttributes);
+			DecisionTree bocuksTree = bocuk.build(emptyObject.getClass());
+			long train_time = System.currentTimeMillis();
+			
+			System.out.println("\nTime to build" + (train_time-st));
+			
+			System.out.println(Util.ntimes("\n", 1)+Util.ntimes("$", 5)+" TESTING "+Util.ntimes("\n", 1));
+			List<Integer> evaluation = bocuk.test(bocuksTree, bocuk.getFacts());//.subList(339, 340));
+			long test_time = System.currentTimeMillis();
+			System.out.println("Time to test" + (test_time-train_time) + "\n" );
+			System.out.println("TESTING results: Mistakes "+ evaluation.get(0));
+			System.out.println("TESTING results: Corrects "+ evaluation.get(1));
+			System.out.println("TESTING results: Unknown "+ evaluation.get(2));
+			
 //			RulePrinter my_printer  = new RulePrinter(bocuk.getNum_fact_trained());
 //			if (max_rules >0)
 //				my_printer.setMax_num_rules(max_rules);
@@ -111,4 +121,27 @@
 
 		
 	}
+	
+	
+	/**
+	 * Reads the objects in datafile via FactSetFactory.fromFileAsObject (all_discrete = false)
+	 * and prints the elapsed wall-clock time in milliseconds.
+	 * @param simple working memory handed to the factory
+	 * @param emptyObject instance whose runtime class is handed to the factory
+	 * @param datafile data file argument handed to the factory
+	 * @param separator separator argument handed to the factory
+	 * @return the list produced by the factory, or null if an exception was thrown (stack trace printed)
+	 */
+	public static List<Object> test_process(WorkingMemory simple, Object emptyObject, String datafile, String separator) {
+		try {
+			final long started = System.currentTimeMillis();
+			List<Object> loaded = FactSetFactory.fromFileAsObject(simple, emptyObject.getClass(), datafile, separator, false);
+			System.out.println("\nTime to process_objects " + (System.currentTimeMillis() - started));
+			return loaded;
+		} catch (Exception e) {
+			e.printStackTrace();
+			return null;
+		}
+	}
+	
 }

Modified: labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java	2008-04-20 13:53:37 UTC (rev 19650)
+++ labs/jbossrules/contrib/machinelearning/decisiontree/src/dt/tools/Util.java	2008-04-20 14:11:35 UTC (rev 19651)
@@ -6,10 +6,11 @@
 
 
 public class Util {
+	public static boolean RUN = true;
+	public static boolean DEBUG = true;
+	public static boolean DEBUG_RETRAIN = true;
+	public static boolean DEBUG_TEST = false;
 	
-	public static boolean DEBUG = false;
-	public static boolean DEBUG_TEST = true;
-	
 	public static String ntimes(String s,int n){
 		StringBuffer buf = new StringBuffer();
 		for (int i = 0; i < n; i++) {




More information about the jboss-svn-commits mailing list