[jboss-svn-commits] JBL Code SVN: r19715 - labs/jbossrules/contrib/machinelearning/dt_examples/src/java/examples.

jboss-svn-commits at lists.jboss.org jboss-svn-commits at lists.jboss.org
Thu Apr 24 16:24:01 EDT 2008


Author: gizil
Date: 2008-04-24 16:24:01 -0400 (Thu, 24 Apr 2008)
New Revision: 19715

Added:
   labs/jbossrules/contrib/machinelearning/dt_examples/src/java/examples/IterationMemoryTest.java
Modified:
   labs/jbossrules/contrib/machinelearning/dt_examples/src/java/examples/IterationTest.java
Log:
Add a test for the latest re-training approach (the one with memory, which saves the matching facts).


Added: labs/jbossrules/contrib/machinelearning/dt_examples/src/java/examples/IterationMemoryTest.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/dt_examples/src/java/examples/IterationMemoryTest.java	                        (rev 0)
+++ labs/jbossrules/contrib/machinelearning/dt_examples/src/java/examples/IterationMemoryTest.java	2008-04-24 20:24:01 UTC (rev 19715)
@@ -0,0 +1,385 @@
+package examples;
+
+import java.io.InputStreamReader;
+import java.io.Reader;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.drools.RuleBase;
+import org.drools.RuleBaseFactory;
+import org.drools.StatefulSession;
+import org.drools.audit.WorkingMemoryFileLogger;
+import org.drools.compiler.PackageBuilder;
+import org.drools.event.DebugAgendaEventListener;
+import org.drools.event.DebugWorkingMemoryEventListener;
+import org.drools.rule.Package;
+
+import dt.DecisionTree;
+import dt.TreeNode;
+import dt.builder.C45TreeBuilder;
+import dt.builder.C45TreeIterator;
+import dt.memory.Fact;
+import dt.memory.WorkingMemory;
+import dt.tools.DecisionTreeSerializer;
+import dt.tools.FileProcessor;
+import dt.tools.ObjectReader;
+import dt.tools.RulePrinter;
+import dt.tools.Util;
+
+public class IterationMemoryTest {
+	/**
+	 * Experiment driver: loads facts from a text file, obtains a decision tree (read from
+	 * disk, or trained when that fails), then re-trains it in two steps and evaluates the
+	 * tree after each step. Every switch is a local variable hard-coded in the body.
+	 */
+	public static final void main(final String[] args) throws Exception {
+		int obj_type= 0;	// 1 selects the Poker data; any other value falls to the Golf default
+		WorkingMemory simple = new WorkingMemory();
+		
+		ArrayList<Fact> _facts = null;
+		String drlFile, inputFile, directory;
+		String build_file = "bocuks.iterator", tree_file = "bocuks_new.tree";	// names appended to directory by read()/write()
+		Object obj;	// prototype instance whose class selects the fact type (see getFacts below)
+		
+		//obj_type= 1;
+		switch (obj_type) {
+		case 1:
+			drlFile = new String("poker_hands_" + ".drl");
+			inputFile = new String("data/poker/poker-hand-training-true.data.txt");
+			directory = new String("data/poker/");	
+			obj = new Poker();
+			
+			break;
+		default:
+			drlFile = new String("golf_" + ".drl");
+			inputFile = new String("data/golf/golf.data.txt");
+			directory = new String("data/golf/");
+			obj = new Golf();
+		}
+		
+		int input_object = 1;	// process text from file
+		//int input_object = 2;	// read from file
+		
+		List<Object> my_objects = null;
+		switch (input_object) {	
+		case 1:	// process from file
+			my_objects = FileProcessor.test_process(simple, obj, inputFile, ",");
+			_facts = simple.getFacts(obj.getClass());	// facts held by simple for the prototype's class
+			break;
+		case 2:	// read from file
+			break;	// nothing is loaded here, so my_objects stays null and main returns below
+		}
+		
+		if (my_objects == null)	{
+			System.out.println("No objects found to process returning  ");
+			return;
+		}
+		// the *_size arrays hold index boundaries into _facts; the upper bound is exclusive
+		int input_tree = 0;	// read from file
+		//int input_tree = 1;		// train a new tree with some part of the facts
+		int[] test_size = {0, _facts.size()}, train_size = {0, _facts.size()};
+		int[] retrain_size = new int[3];	// two re-training slices: [0]..[1] and [1]..[2]
+		boolean retrain_tree = true, test_tree = true, print_tree = true, write_tree = true, test_rules = false, parse_w_drools = false;
+		switch (obj_type) {
+		case 1:
+			test_size[1] = 2;
+			train_size[1] = 2;
+			retrain_size [1] = 2;
+			break;	// NOTE(review): retrain_size[2] stays 0 here, so the second re-training slice below gets a negative capacity
+		default:
+			test_size[1] = 14;	// NOTE(review): hard-coded; subList below fails if fewer than 14 facts were loaded
+			train_size[1] = 4;
+			retrain_size[0] = train_size[1];
+			retrain_size[1] = retrain_size[0]+4;
+			retrain_size[2] = _facts.size();
+		}
+		
+		C45TreeIterator bocuk = null;
+		DecisionTree bocukTree = null;
+		boolean tree_read = true;
+		if (input_tree == 0) {	
+			// read the matching facts from file
+			try {
+				bocuk = (C45TreeIterator)read(directory, build_file);
+				bocukTree = (DecisionTree)read(directory, tree_file);
+			} catch (Exception e) {
+				System.out.println("EXCEPTION:  Could not read the tree "+ e);
+				e.printStackTrace();
+				tree_read = false;	// triggers the training fallback below
+			}
+//			catch (FileNotFoundException e) {
+//				// TODO Auto-generated catch block
+//				e.printStackTrace();
+//			} catch (IOException e) {
+//				// TODO Auto-generated catch block
+//				e.printStackTrace();
+//			} catch (ClassNotFoundException e) {
+//				// TODO Auto-generated catch block
+//				e.printStackTrace();
+//			}
+		}	
+		if (input_tree > 0 || !tree_read || bocukTree==null || bocuk==null) {
+			if (!tree_read)	System.out.println("EXCEPTION:  Could not read the tree so training ");
+			// train from scratch on facts [train_size[0], train_size[1])
+			bocuk = builder(simple, obj);
+			ArrayList<Fact> training_list = new ArrayList<Fact>(train_size[1]-train_size[0]+1);
+			for (int i= train_size[0]; i<train_size[1]; i++)
+				training_list.add(_facts.get(i));
+			/* find the matching facts */
+			bocukTree = train(bocuk, obj, training_list);
+			
+			write_tree = true;
+		}
+		
+		if (bocukTree == null)	{
+			System.out.println("No decision tree found to process returning  ");
+			return;
+		} else {
+			System.out.println("!!My matching facts: \n"+ Util.ntimes("\n", 3));
+			for (TreeNode obj_node : bocuk.getMatchingFacts().keySet())
+				System.out.println("* "+obj_node.toString(1)+ " => "+bocuk.getMatchingFacts().get(obj_node) );
+			
+			if (print_tree) 
+				System.out.println("INPUT TREE: "+ bocukTree.toString(bocuk.getMatchingFacts()));
+		}
+		
+		if (retrain_tree) {
+			// retrain a tree with the first slice: facts [retrain_size[0], retrain_size[1])
+			
+			ArrayList<Fact> re_training_list = new ArrayList<Fact>(retrain_size[1]-retrain_size[0]);
+			for (int i= retrain_size[0]; i<retrain_size[1]; i++)
+				re_training_list.add(_facts.get(i));
+			bocukTree = retrain(bocuk, bocukTree, re_training_list);
+			
+			write_tree = false;	// NOTE(review): also cancels the write_tree = true set after a fresh training; confirm intended
+		}
+		
+		List<Integer> test_result = null;	// printed below as [0] mistakes, [1] corrects, [2] unknown
+		if (test_tree)	{
+			System.out.println("TEST"+ input_tree);
+			//test_result = evaluation_test(bocuk, bocukTree, 0, bocuk.getNum_fact_trained());
+			//test_result = evaluation_test(bocuk, bocukTree, (int)(bocuk.getNum_fact_trained()*3/5), (int)(bocuk.getNum_fact_trained()*4/5));
+			
+			test_result = evaluation(bocuk, bocukTree, _facts.subList(test_size[0], test_size[1]));
+			System.out.print("Test results: \tMistakes "+ test_result.get(0));
+			System.out.print("\tCorrects "+ test_result.get(1));
+			System.out.println("\t Unknown "+ test_result.get(2) +" OF "+ (test_size[1]-test_size[0]) + " facts" );
+			if (print_tree)
+				System.out.println("TESTED TREE: "+ bocukTree.toString(bocuk.getMatchingFacts()));
+		}
+		
+		if (retrain_tree) {
+			// retrain a tree with the second slice: facts [retrain_size[1], retrain_size[2])
+			
+			ArrayList<Fact> re_training_list = new ArrayList<Fact>(retrain_size[2]-retrain_size[1]);
+			for (int i= retrain_size[1]; i<retrain_size[2]; i++)
+				re_training_list.add(_facts.get(i));
+			bocukTree = retrain(bocuk, bocukTree, re_training_list);
+			
+			write_tree = false;
+		}
+		// second evaluation, over the same test slice as the first one
+		if (test_tree)	{
+			System.out.println("TEST2"+ input_tree);
+			//test_result = evaluation_test(bocuk, bocukTree, 0, bocuk.getNum_fact_trained());
+			//test_result = evaluation_test(bocuk, bocukTree, (int)(bocuk.getNum_fact_trained()*3/5), (int)(bocuk.getNum_fact_trained()*4/5));
+			
+			test_result = evaluation(bocuk, bocukTree, _facts.subList(test_size[0], test_size[1]));
+			System.out.print("Test results: \tMistakes "+ test_result.get(0));
+			System.out.print("\tCorrects "+ test_result.get(1));
+			System.out.println("\t Unknown "+ test_result.get(2) +" OF "+ (test_size[1]-test_size[0]) + " facts" );
+			if (print_tree)
+				System.out.println("TESTED TREE: "+ bocukTree.toString(bocuk.getMatchingFacts()));
+		}
+		// persist both the iterator and the tree; skipped whenever a re-training step cleared write_tree
+		if (write_tree) {
+			write(bocuk, directory, build_file);
+			write(bocukTree, directory, tree_file);
+		}
+
+		/* create the drl */
+		if (test_rules) { 
+			int max_rules = -1; // no limit print all
+			rules(bocuk, bocukTree, drlFile, max_rules); 
+		}
+		
+		if (parse_w_drools) {
+			drools(drlFile, my_objects);
+		}
+	}
+	
+	/**
+	 * Creates a C45TreeIterator for the class of {@code emptyObject}: the target is taken
+	 * from ObjectReader.getTargetAnnotation and each working attribute is registered with
+	 * its domain from {@code simple}. The elapsed milliseconds are printed to standard output.
+	 */
+	public static C45TreeIterator builder(WorkingMemory simple, Object emptyObject) {
+		final long startMillis = System.currentTimeMillis();
+		final Class<?> factClass = emptyObject.getClass();
+		final String targetName = ObjectReader.getTargetAnnotation(factClass);
+		final List<String> attributeNames = ObjectReader.getWorkingAttributes(factClass);
+		final C45TreeBuilder treeBuilder = new C45TreeBuilder();
+		treeBuilder.setTarget(targetName);
+		final Iterator<String> names = attributeNames.iterator();
+		while (names.hasNext()) {
+			final String name = names.next();
+			treeBuilder.addAttribute(name);
+			treeBuilder.addDomain(simple.getDomain(name));
+		}
+		final C45TreeIterator treeIterator = new C45TreeIterator(treeBuilder);
+		System.out.println("\nTime to builder " + (System.currentTimeMillis() - startMillis));
+		return treeIterator;
+	}
+	
+	/**
+	 * Builds a tree for the class of {@code emptyObject} from {@code facts} via builder.build_to_iterate and prints the elapsed milliseconds.
+	 */
+	public static DecisionTree train(C45TreeIterator builder, Object emptyObject, ArrayList<Fact> facts) {
+		final long startMillis = System.currentTimeMillis();
+		final DecisionTree trainedTree = builder.build_to_iterate(emptyObject.getClass(), facts);
+		final long elapsedMillis = System.currentTimeMillis() - startMillis;
+		System.out.println("\nTime to train_decision_tree " + elapsedMillis);
+		return trainedTree;
+	}
+	
+	/**
+	 * Re-builds {@code dt} with {@code facts} via builder.re_build and prints the elapsed milliseconds, labelled as a retrain so it is not confused with train().
+	 */
+	public static DecisionTree retrain(C45TreeIterator builder, DecisionTree dt, ArrayList<Fact> facts) {
+		final long st = System.currentTimeMillis();
+		final DecisionTree bocuksTree = builder.re_build(dt, facts);
+		final long retrain_time = System.currentTimeMillis();
+		System.out.println("\nTime to retrain_decision_tree " + (retrain_time-st));
+		return bocuksTree;
+	}
+	
+	/**
+	 * Serializes {@code tree} with DecisionTreeSerializer.write to the path formed by
+	 * concatenating {@code outputdirectory} and {@code file} (no separator is inserted),
+	 * then prints the elapsed milliseconds to standard output.
+	 */
+	public static void write(Object tree, String outputdirectory, String file) {
+		final long startMillis = System.currentTimeMillis();
+		final String targetPath = outputdirectory + file;
+		DecisionTreeSerializer.write(tree, targetPath);
+		final long elapsedMillis = System.currentTimeMillis() - startMillis;
+		System.out.println("Time to write_decision_tree " + elapsedMillis + "\n");
+	}
+	
+	/**
+	 * Deserializes an object with DecisionTreeSerializer.read from the path formed
+	 * by concatenating {@code directory} and {@code file}, then prints the elapsed
+	 * milliseconds to standard output.
+	 *
+	 * @param directory path prefix, concatenated as-is (no separator is inserted)
+	 * @param file      name appended to {@code directory}; also echoed in the
+	 *                  timing message
+	 * @return the object returned by DecisionTreeSerializer.read
+	 * @throws Exception whatever DecisionTreeSerializer.read throws; nothing is
+	 *                   caught here
+	 */
+	public static Object read(String directory, String file) throws Exception {
+		final long startMillis = System.currentTimeMillis();
+		final Object loaded = DecisionTreeSerializer.read(directory + file);
+		final long elapsedMillis = System.currentTimeMillis() - startMillis;
+		System.out.println("Time to read_" + file + " " + elapsedMillis + "\n");
+		return loaded;
+	}
+	
+
+	/**
+	 * Runs builder.test on {@code tree} with {@code facts}, prints the elapsed
+	 * milliseconds, and returns the list produced by that call unchanged.
+	 */
+	public static List<Integer> evaluation(C45TreeIterator builder, DecisionTree tree, List<Fact> facts) {
+		final long startMillis = System.currentTimeMillis();
+		final List<Integer> outcome = builder.test(tree, facts);
+		System.out.println("Time to test_decision_tree " + (System.currentTimeMillis() - startMillis) + "\n");
+		return outcome;
+	}
+	
+	
+	/**
+	 * Hands {@code tree} to a RulePrinter and writes the result to "src/rules/examples/" + {@code drlfile};
+	 * the rule limit is only set when {@code max_rules} is positive. Elapsed milliseconds are printed.
+	 */
+	public static void rules (C45TreeIterator builder, DecisionTree tree, String drlfile, int max_rules) {
+		final long startMillis = System.currentTimeMillis();
+		final RulePrinter rulePrinter = new RulePrinter(builder.getNum_fact_trained());
+		if (max_rules > 0) {
+			rulePrinter.setMax_num_rules(max_rules);
+		}
+		rulePrinter.printer(tree, /* sort_via_rank */ true, /* print */ true);
+		rulePrinter.write2file("examples", "src/rules/examples/" + drlfile);
+		System.out.println("Time to print_rules " + (System.currentTimeMillis() - startMillis) + "\n");
+	}
+	/**
+	 * Compiles the DRL resource {@code drlFile}, looked up relative to Golf.class,
+	 * and deploys the resulting package to a new rule base. When the local
+	 * load_to_drools switch is enabled, the objects are also inserted into a
+	 * stateful session (logged to "log/golf") and all rules are fired.
+	 * @param drlFile    resource name handed to Golf.class.getResourceAsStream
+	 * @param my_objects objects inserted into the session when loading is enabled
+	 * @throws java.io.FileNotFoundException if the resource cannot be found
+	 * @throws RuntimeException if the package builder reports compile errors
+	 * @throws Exception anything else raised while reading or compiling the DRL
+	 */
+	public static void drools(String drlFile, List<Object> my_objects) throws Exception{
+		// getResourceAsStream returns null for a missing resource; report that
+		// explicitly rather than letting InputStreamReader fail with a bare NPE.
+		final java.io.InputStream stream = Golf.class.getResourceAsStream(drlFile);
+		if (stream == null) {
+			throw new java.io.FileNotFoundException("Unable to find \"" + drlFile + "\" relative to " + Golf.class.getName());
+		}
+		final Reader source = new InputStreamReader(stream);
+		final PackageBuilder builder;
+		try {
+			builder = new PackageBuilder();
+			//this will parse and compile in one step
+			builder.addPackageFromDrl(source);
+		} finally {
+			// the reader is only needed while parsing, so release it on every path
+			source.close();
+		}
+		// Check the builder for errors
+		if (builder.hasErrors()) {
+			System.out.println(builder.getErrors().toString());
+			throw new RuntimeException("Unable to compile \"" + drlFile + "\".");
+		}
+		//get the compiled package (which is serializable)
+		final Package pkg = builder.getPackage();
+
+		//add the package to a rulebase (deploy the rule package).
+		final RuleBase ruleBase = RuleBaseFactory.newRuleBase();
+		ruleBase.addPackage(pkg);
+
+		boolean load_to_drools = false;
+		if (load_to_drools) {
+			/* feeding the object to Drools working memory */
+			final StatefulSession session = ruleBase.newStatefulSession();
+			try {
+				session.addEventListener(new DebugAgendaEventListener());
+				session.addEventListener(new DebugWorkingMemoryEventListener());
+				final WorkingMemoryFileLogger logger = new WorkingMemoryFileLogger(session);
+				logger.setFileName("log/golf");
+				for (Object obj : my_objects) {
+					try {
+						session.insert(obj);
+					} catch (Exception e) {
+						// best effort: report the failing object and keep inserting the rest
+						System.out.println("Inserting element " + obj + " and " + e);
+					}
+				}
+				session.fireAllRules();
+				logger.writeToDisk();
+			} finally {
+				// dispose the session even when firing or logging fails
+				session.dispose();
+			}
+		}
+		System.out.println("Happy ending");
+	}
+}
+
+

Modified: labs/jbossrules/contrib/machinelearning/dt_examples/src/java/examples/IterationTest.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/dt_examples/src/java/examples/IterationTest.java	2008-04-24 20:20:49 UTC (rev 19714)
+++ labs/jbossrules/contrib/machinelearning/dt_examples/src/java/examples/IterationTest.java	2008-04-24 20:24:01 UTC (rev 19715)
@@ -67,10 +67,10 @@
 			return;
 		}
 		
-		//int input_tree = 1;		// train a new tree with some part of the facts
-		int input_tree = 2;	// read from file
+		int input_tree = 1;		// train a new tree with some part of the facts
+		//int input_tree = 2;	// read from file
 		int[] test_size = {0, _facts.size()}, train_size = {0, _facts.size()}, retrain_size = {0, _facts.size()};
-		boolean retrain_tree = true, test_tree = true, print_tree = true, write_tree = false, test_rules = false, parse_w_drools = false;
+		boolean retrain_tree = false, test_tree = true, print_tree = true, write_tree = true, test_rules = false, parse_w_drools = false;
 		switch (obj_type) {
 		case 1:
 			test_size[1] = 2;
@@ -79,7 +79,7 @@
 			break;
 		default:
 			test_size[1] = 14;
-			train_size[1] = 7;
+			//train_size[1] = 7;
 			retrain_size[0] = train_size[1];
 			retrain_size[1] = 14;
 		}




More information about the jboss-svn-commits mailing list