[jboss-svn-commits] JBL Code SVN: r20815 - in labs/jbossrules/contrib/machinelearning/4.0.x: drools-core/src/main/java/org/drools/learner/builder and 3 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Thu Jun 26 12:17:29 EDT 2008
Author: gizil
Date: 2008-06-26 12:17:29 -0400 (Thu, 26 Jun 2008)
New Revision: 20815
Added:
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/NodeValue.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Path.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeMerger.java
Modified:
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Domain.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Learner.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/AttributeChooser.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/GainRatio.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/LoggerFactory.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/RulePrinter.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
Log:
DecisionTreeMerger: merge the output trees produced by bagging and boosting
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -59,9 +59,13 @@
this.target = parentTree.getTargetDomain();
this.attrsToClassify = new ArrayList<Domain>(parentTree.getAttrDomains().size()-1);
for (Domain attr_domain : parentTree.getAttrDomains()) {
- if (!attr_domain.getFName().equals(exceptDomain.getFName()))
+ if (attr_domain.isNotJustSelected(exceptDomain))
this.attrsToClassify.add(attr_domain);
}
+// System.out.print("New tree ");
+// for (Domain d:attrsToClassify)
+// System.out.print("d: "+d);
+// System.out.println("");
//Collections.sort(this.attrsToClassify, new Comparator<Domain>()); // compare the domains by the name
}
@@ -77,7 +81,11 @@
public HashMap<String, ArrayList<Field>> getAttrRelationMap() {
return obj_schema.getAttrRelationMap();
}
-
+
+ public int getId() { // the id assigned with setID(int)
+ return id;
+ }
+
public void setID(int i) {
this.id = i;
}
Added: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -0,0 +1,143 @@
+package org.drools.learner;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Stack;
+
+/**
+ * Walks decision trees depth-first and collects every distinct root-to-leaf
+ * {@link Path}. Calling {@link #visit} for several trees (e.g. the members of
+ * a bagged or boosted forest) accumulates their paths in one collection.
+ */
+public class DecisionTreeVisitor {
+
+    /** Entries from the root down to the node being visited, each with the branch value taken. */
+    private final Stack<NodeValue> nodes;
+
+    /**
+     * Distinct paths in first-seen order, keyed by the stringCode() of every step
+     * (leaf included). A structural key is used rather than Path.hashCode() so that
+     * different paths whose hash codes collide are not silently dropped.
+     */
+    private final Map<List<String>, Path> paths;
+
+    /** Number of root-to-leaf paths visited, duplicates included. */
+    private int num_paths_found;
+
+    public DecisionTreeVisitor() {
+        this.nodes = new Stack<NodeValue>();
+        this.paths = new LinkedHashMap<List<String>, Path>();
+        this.num_paths_found = 0;
+    }
+
+    /**
+     * Collects every root-to-leaf path of {@code dt}. A path that was already
+     * collected (from this or an earlier tree) is counted but keeps its first tree id.
+     */
+    public void visit(DecisionTree dt) {
+        TreeNode root = dt.getRoot();
+        if (root == null)
+            return; // empty tree: nothing to collect
+        dfs(root, dt.getId());
+    }
+
+    /**
+     * Same result as {@link #visit}, but reuses one NodeValue per tree level
+     * instead of allocating one per branch.
+     */
+    public void visit2(DecisionTree dt) {
+        TreeNode root = dt.getRoot();
+        if (root == null)
+            return; // empty tree: nothing to collect
+        dfs2(root, dt.getId());
+        nodes.pop(); // dfs2 leaves the root's entry on the stack
+    }
+
+    private void dfs(TreeNode my_node, int tree_id) {
+        if (my_node instanceof LeafNode) {
+            NodeValue leaf_value = new NodeValue(my_node);
+            leaf_value.setValue(((LeafNode) my_node).getCategory());
+            nodes.push(leaf_value);
+            recordPath(tree_id);
+            nodes.pop();
+            return;
+        }
+
+        for (Object attributeValue : my_node.getChildrenKeys()) {
+            NodeValue node_value = new NodeValue(my_node);
+            node_value.setValue(attributeValue);
+            nodes.push(node_value);
+            dfs(my_node.getChild(attributeValue), tree_id);
+            nodes.pop();
+        }
+    }
+
+    // Pushes an entry for my_node and leaves it on the stack; the caller pops it.
+    private void dfs2(TreeNode my_node, int tree_id) {
+        NodeValue node_value = new NodeValue(my_node);
+        nodes.push(node_value);
+        if (my_node instanceof LeafNode) {
+            node_value.setValue(((LeafNode) my_node).getCategory());
+            recordPath(tree_id);
+            return;
+        }
+
+        for (Object attributeValue : my_node.getChildrenKeys()) {
+            node_value.setValue(attributeValue); // the same entry is relabelled per branch
+            dfs2(my_node.getChild(attributeValue), tree_id);
+            nodes.pop(); // the child's entry
+        }
+    }
+
+    // Counts the path on the stack (leaf on top) and stores it if it was not seen before.
+    private void recordPath(int tree_id) {
+        num_paths_found++;
+        List<String> key = pathKey(nodes);
+        if (!paths.containsKey(key)) {
+            Path p = spitPath(nodes);
+            p.setTreeId(tree_id);
+            paths.put(key, p);
+        }
+    }
+
+    /** @return the number of distinct paths collected so far */
+    public int getNumPaths() {
+        return paths.size();
+    }
+
+    /** @return the number of root-to-leaf paths visited so far, duplicates included */
+    public int getNumPathsFound() {
+        return num_paths_found;
+    }
+
+    /** @return a live view of the distinct paths, in the order they were first seen */
+    public Collection<Path> getPathList() {
+        return paths.values();
+    }
+
+    // Identity of a path: the stringCode() of every step, root first, leaf last.
+    private static List<String> pathKey(Stack<NodeValue> nodes) {
+        List<String> key = new ArrayList<String>(nodes.size());
+        for (NodeValue nv : nodes)
+            key.add(nv.stringCode());
+        return key;
+    }
+
+    // Copies the stack into a Path: every entry but the last is a condition, the last is the action.
+    private Path spitPath(Stack<NodeValue> nodes) {
+        Path newPath = new Path(nodes.size());
+        Iterator<NodeValue> it = nodes.iterator();
+        while (it.hasNext()) {
+            NodeValue current = it.next();
+            if (it.hasNext())
+                newPath.addStep(current);
+            else
+                newPath.setStats(current);
+        }
+        return newPath;
+    }
+}
Property changes on: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTreeVisitor.java
___________________________________________________________________
Name: svn:eol-style
+ native
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Domain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Domain.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Domain.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -155,8 +155,14 @@
}
+ public boolean isNotJustSelected(Domain exceptDomain) { // false only for the same field name of the same object class
+ if (this.objKlass.equals(exceptDomain.getObjKlass())) // same class: decided by the field name
+ return !this.getFName().equals(exceptDomain.getFName());
+ return true; // a field of a different class is never the excluded one
+ }
+
public int hashCode() {
- return fName.hashCode() ^ fCategories.hashCode(); // TODO
+ return objKlass.hashCode() ^ fName.hashCode() ^ fCategories.hashCode(); // TODO
}
public String toString() {
@@ -167,12 +173,5 @@
return sb.toString();
}
-
-// public DataType getDataType() {
-// return this.dataType;
-// }
-// public void setDataType(DataType data_type) {
-// this.dataType = data_type;
-// }
}
Added: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/NodeValue.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/NodeValue.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/NodeValue.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -0,0 +1,105 @@
+package org.drools.learner;
+
+/**
+ * One step of a decision-tree path: a tree node together with the value taken
+ * at it (the branch category for an inner node, the class label for a leaf).
+ */
+public class NodeValue {
+
+    private final TreeNode node;
+    private Object nodeValue;
+
+    public NodeValue(TreeNode n) {
+        this.node = n;
+    }
+
+    public String getFReference() {
+        return node.getDomain().getFReferenceName();
+    }
+
+    public String getFName() {
+        return node.getDomain().getFName();
+    }
+
+    public Object getValue() {
+        return nodeValue;
+    }
+
+    public void setValue(Object category) {
+        this.nodeValue = category;
+    }
+
+    public TreeNode getNode() {
+        return node;
+    }
+
+    /** Hash of {@link #stringCode()}: steps with the same string code hash alike. */
+    @Override
+    public int hashCode() {
+        return stringCode().hashCode();
+    }
+
+    /** @return "fully.qualified.ObjectClass.fieldName.value", identifying the step across trees */
+    public String stringCode() {
+        return node.getDomain().getObjKlass().getName() + "." + node.getDomain().getFName() + "." + nodeValue;
+    }
+
+    /**
+     * Renders the step as a rule condition: {@code fName == value} for categorical
+     * domains, otherwise {@code fName <= value} or {@code fName > category}.
+     *
+     * @throws UnsupportedOperationException if the domain is not categorical and the value is not a Number
+     * @throws IllegalStateException if the value or a domain category is not Comparable
+     */
+    @Override
+    public String toString() {
+        Domain domain = node.getDomain();
+        String fName = this.getFName();
+
+        String value;
+        if (domain.getFType() == String.class)
+            value = "\"" + nodeValue + "\"";
+        else
+            value = nodeValue + "";
+
+        if (domain.isCategorical())
+            return fName + " == " + value;
+
+        int size = domain.getCategoryCount() - 1;
+        int idx = indexOfCategory(domain, size);
+        if (idx == 0)
+            return fName + " <= " + value;
+        else if (idx == size)
+            // last category: the condition is fName > category(last-1)
+            return fName + " > " + domain.getCategory(size - 1);
+        else
+            // only the upper bound is emitted; the author noted drools would not accept
+            // the chained form category(idx) < fName <= category(idx+1)
+            return fName + " <= " + value;
+    }
+
+    // Index of the category equal to nodeValue, scanning from `last` down to 0; -1 if none matches.
+    // Throws instead of calling System.exit so a bad value cannot kill the JVM from toString().
+    private int indexOfCategory(Domain domain, int last) {
+        if (!(nodeValue instanceof Number)) {
+            // TODO support non-numeric values of non-categorical domains
+            throw new UnsupportedOperationException("Non-categorical domain " + getFName()
+                    + " needs a numeric value, got: " + describe(nodeValue));
+        }
+        int idx = last;
+        for (; idx >= 0; idx--) {
+            Object categoryValue = domain.getCategory(idx);
+            if (!(nodeValue instanceof Comparable) || !(categoryValue instanceof Comparable)) {
+                throw new IllegalStateException("Cannot compare value " + describe(nodeValue)
+                        + " with category " + describe(categoryValue));
+            }
+            if (AttributeValueComparator.instance.compare(nodeValue, categoryValue) == 0)
+                break;
+        }
+        return idx;
+    }
+
+    private static String describe(Object o) {
+        return o == null ? "null" : o + " (" + o.getClass().getName() + ")";
+    }
+}
\ No newline at end of file
Property changes on: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/NodeValue.java
___________________________________________________________________
Name: svn:eol-style
+ native
Added: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Path.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Path.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Path.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -0,0 +1,116 @@
+package org.drools.learner;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+
+/**
+ * A root-to-leaf path of a decision tree: the ordered conditions (inner nodes
+ * with the branch value taken) followed by the action (the leaf step).
+ */
+public class Path {
+
+    private int code; // hash accumulated over all steps by updateHashCode
+    private Class<?> attr_obj; // object class of the rule, see getObjectClassName
+    private final ArrayList<NodeValue> conditions;
+    private NodeValue action; // leaf step; rank and counts are read from its LeafNode
+    private int treeId; // id of the tree the path belongs to
+
+    public Path(int numCond) {
+        conditions = new ArrayList<NodeValue>(numCond);
+        code = 0;
+    }
+
+    /** Appends a copy of {@code current} as the next condition. */
+    public void addStep(NodeValue current) {
+        NodeValue nv = copyOf(current);
+        updateHashCode(nv);
+        conditions.add(nv);
+    }
+
+    /** Sets a copy of {@code current} as the action; its node must be a LeafNode for the getters below. */
+    public void setStats(NodeValue current) {
+        action = copyOf(current);
+        updateHashCode(action);
+    }
+
+    // Snapshot of a step: the caller may later change the value of the NodeValue it passed in.
+    private static NodeValue copyOf(NodeValue current) {
+        NodeValue copy = new NodeValue(current.getNode());
+        copy.setValue(current.getValue());
+        return copy;
+    }
+
+    public Iterator<NodeValue> getConditionIterator() {
+        return conditions.iterator();
+    }
+
+    public NodeValue getAction() {
+        return action;
+    }
+
+    /**
+     * Folds one step into the path hash with the usual {@code 31 * h + x} scheme.
+     * A shift-based fold such as {@code (h + x) << 6} pushes the earliest steps out
+     * of the 32-bit int, so long paths differing only near the root would collide.
+     */
+    public void updateHashCode(NodeValue nv) {
+        code = 31 * code + nv.hashCode();
+    }
+
+    public void setObjectClass(Class<?> obj) {
+        attr_obj = obj;
+    }
+
+    /** @throws NullPointerException if {@link #setObjectClass} was never called */
+    public String getObjectClassName() {
+        return attr_obj.getSimpleName();
+    }
+
+    public int getTreeId() {
+        return treeId;
+    }
+
+    public void setTreeId(int id) {
+        treeId = id;
+    }
+
+    public double getNumClassified() {
+        return ((LeafNode) this.action.getNode()).getNumMatch();
+    }
+
+    public double getRank() {
+        return ((LeafNode) this.action.getNode()).getRank();
+    }
+
+    public double getInfoMea() {
+        return ((LeafNode) this.action.getNode()).getInfoMea();
+    }
+
+    /** @return a comparator ordering paths from highest to lowest rank */
+    public static Comparator<Path> getPathRankComparator() {
+        return new PathComparator();
+    }
+
+    private static class PathComparator implements Comparator<Path> {
+        // Double.compare is a total order (NaN included), as Collections.sort requires;
+        // the arguments are swapped to sort by descending rank
+        public int compare(Path r1, Path r2) {
+            return Double.compare(r2.getRank(), r1.getRank());
+        }
+    }
+
+    @Override
+    public int hashCode() {
+        return code;
+    }
+
+    @Override
+    public String toString() {
+        StringBuilder out_bf = new StringBuilder();
+        for (NodeValue c : conditions)
+            out_bf.append(c).append(", ");
+        out_bf.append(" => ").append(action);
+        return out_bf.toString();
+    }
+}
\ No newline at end of file
Property changes on: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Path.java
___________________________________________________________________
Name: svn:eol-style
+ native
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -17,7 +17,7 @@
* gain:
* gainRatio
*/
- private double rank, gain, gainRatio;
+ private double rank, infoMea;
// Number of all instances matching at that node
private double num_matching_instances;
@@ -60,6 +60,14 @@
return children.get(attr_key);
}
+ public double getInfoMea() { // information measure stored at this node (C45Learner sets it to the chosen attribute's evaluation)
+ return infoMea;
+ }
+
+ public void setInfoMea(double mea) {
+ this.infoMea = mea;
+ }
+
public Object voteFor(Instance i) {
final Object attr_value = i.getAttrValue(this.domain.getFReferenceName());
final Object category = domain.getCategoryOf(attr_value);
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/AdaBoostBuilder.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -26,10 +26,12 @@
private ArrayList<DecisionTree> forest;
private ArrayList<Double> classifier_accuracy;
- //private Learner trainer;
+ private DecisionTreeMerger merger;
+
public AdaBoostBuilder() {
//this.trainer = _trainer;
+ merger = new DecisionTreeMerger();
}
public void build(Memory mem, Learner _trainer) {
@@ -132,8 +134,10 @@
else {
- if (slog.stat() != null)
+ if (slog.stat() != null) {
+ slog.stat().log("\n Boosting ends: ");
slog.stat().log("All instances classified correctly TERMINATE, forest size:"+i+ "\n");
+ }
// What to do here??
FOREST_SIZE = i;
classifier_accuracy.add(10.0); // TODO add a very big number
@@ -143,12 +147,18 @@
forest.add(dt);
+ // the DecisionTreeMerger will visit the decision tree and add the paths that have not been seen yet to the list
+ merger.add(dt);
if (slog.stat() !=null)
slog.stat().stat(".");
}
// TODO how to compute a best tree from the forest
+ DecisionTree best = merger.getBest();
+ if (best == null)
+ best = forest.get(0);
+
_trainer.setBestTree(forest.get(0));
//this.c45 = dt;
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -12,6 +12,7 @@
import org.drools.learner.eval.InformationContainer;
import org.drools.learner.eval.InstDistribution;
import org.drools.learner.tools.FeatureNotSupported;
+import org.drools.learner.tools.Util;
public class C45Learner extends Learner{
@@ -44,6 +45,7 @@
(double)this.getDataSize()/* total size of data fed to dt*/);
classifiedNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
classifiedNode.setNumClassification(data_stats.getSum()); //num of classified instances at the leaf node
+ //classifiedNode.setInfoMea(mea)
return classifiedNode;
}
@@ -58,7 +60,7 @@
(double)this.getDataSize() /* total size of data fed to dt*/);
noAttributeLeftNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
noAttributeLeftNode.setNumClassification(data_stats.getVoteFor(winner)); //num of classified instances at the leaf node
-
+ //noAttributeLeftNode.setInfoMea(best_attr_eval.attribute_eval);
/* we need to know how many guys cannot be classified and who these guys are */
data_stats.missClassifiedInstances(missclassified_data);
@@ -72,14 +74,14 @@
chooser.chooseAttribute(best_attr_eval, data_stats, attribute_domains);
Domain node_domain = best_attr_eval.domain;
-
-
- //flog.debug(Util.ntimes("*", 20)+" 1st best attr: "+ node_domain);
+ if (slog.debug() != null)
+ slog.debug().log("\n"+Util.ntimes("*", 20)+" 1st best attr: "+ node_domain);
TreeNode currentNode = new TreeNode(node_domain);
currentNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
currentNode.setRank((double)data_stats.getSum()/
(double)this.getDataSize() /* total size of data fed to dt*/);
+ currentNode.setInfoMea(best_attr_eval.attribute_eval);
Hashtable<Object, InstDistribution> filtered_stats = null;
@@ -95,6 +97,8 @@
for (int c = 0; c<node_domain.getCategoryCount(); c++) {
/* split the last two class at the same time */
Object category = node_domain.getCategory(c);
+ if (slog.debug() != null)
+ slog.debug().log("{"+ node_domain +":"+category+ "}");
/* list of domains except the choosen one (&target domain)*/
DecisionTree child_dt = new DecisionTree(dt, node_domain);
@@ -105,6 +109,7 @@
majorityNode.setRank(-1.0); //it does not classify any instance
majorityNode.setNumMatch(0);
majorityNode.setNumClassification(0);
+ //currentNode.setInfoMea(best_attr_eval.attribute_eval);
currentNode.putNode(category, majorityNode);
} else {
TreeNode newNode = train(child_dt, filtered_stats.get(category));//, attributeNames_copy
Added: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeMerger.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeMerger.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeMerger.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -0,0 +1,92 @@
+package org.drools.learner.builder;
+
+import java.util.ArrayList;
+import java.util.Collections;
+
+import org.drools.learner.DecisionTree;
+import org.drools.learner.DecisionTreeVisitor;
+import org.drools.learner.Path;
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
+
+/**
+ * Collects the root-to-leaf paths of several decision trees (e.g. the members of
+ * a bagged or boosted forest) so they can be ranked and later merged.
+ */
+public class DecisionTreeMerger {
+
+    private static final SimpleLogger flog = LoggerFactory.getUniqueFileLogger(DecisionTreeMerger.class, SimpleLogger.WARN);
+    private static final SimpleLogger slog = LoggerFactory.getSysOutLogger(DecisionTreeMerger.class, SimpleLogger.WARN);
+
+    private final DecisionTreeVisitor visitor;
+
+    /** Paths by decreasing rank, as of the last sortPaths() call. */
+    private ArrayList<Path> sorted_paths;
+
+    public DecisionTreeMerger() {
+        visitor = new DecisionTreeVisitor();
+        sorted_paths = new ArrayList<Path>(); // so printSortedPaths() works before sortPaths()
+    }
+
+    /** Visits {@code dt} and keeps the paths that have not been seen yet. */
+    public void add(DecisionTree dt) {
+        visitor.visit(dt);
+    }
+
+    /** Same as {@link #add}, via DecisionTreeVisitor.visit2. */
+    public void add2(DecisionTree dt) {
+        visitor.visit2(dt);
+    }
+
+    /**
+     * Sorts and logs the collected paths. Building a merged tree is not implemented
+     * yet, so this always returns {@code null}; callers must fall back themselves.
+     */
+    public DecisionTree getBest() {
+        sortPaths();
+        printSortedPaths();
+        return null; // TODO build a tree from the best-ranked paths
+    }
+
+    public void sortPaths() {
+        sorted_paths = new ArrayList<Path>(visitor.getPathList());
+        Collections.sort(sorted_paths, Path.getPathRankComparator());
+    }
+
+    /** @return the number of distinct paths collected */
+    public int getNumPaths() {
+        return visitor.getNumPaths();
+    }
+
+    /** @return the number of paths visited, duplicates included */
+    public int getNumPathsfound() {
+        return visitor.getNumPathsFound();
+    }
+
+    /** Logs every collected path, unsorted, to the file logger. */
+    public void printPaths() {
+        if (flog.warn() == null)
+            return;
+        for (Path p : visitor.getPathList())
+            flog.warn().log(p.hashCode() + "-" + p.getRank() + " : " + p + "\n");
+    }
+
+    /**
+     * Logs the result of the last sortPaths() to the file and console loggers.
+     * Each logger is checked on its own, so either one may be disabled.
+     */
+    public void printSortedPaths() {
+        String counts = "Total num of paths " + getNumPathsfound() + " num paths different" + getNumPaths() + "\n";
+        logSortedPaths(flog, "Sorted paths: " + counts);
+        logSortedPaths(slog, counts);
+    }
+
+    // Writes the header and one line per sorted path; a null warn() is treated as disabled.
+    private void logSortedPaths(SimpleLogger logger, String header) {
+        if (logger.warn() == null)
+            return;
+        logger.warn().log(header);
+        for (Path p : sorted_paths)
+            logger.warn().log(p.hashCode() + "-" + p.getTreeId() + "-" + p.getRank() + " : " + p + "\n");
+    }
+}
Property changes on: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeMerger.java
___________________________________________________________________
Name: svn:eol-style
+ native
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -12,7 +12,7 @@
public class ForestBuilder implements DecisionTreeBuilder{
private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(ForestBuilder.class, SimpleLogger.DEFAULT_LEVEL);
- private static SimpleLogger slog = LoggerFactory.getSysOutLogger(ForestBuilder.class, SimpleLogger.DEFAULT_LEVEL);
+ private static SimpleLogger slog = LoggerFactory.getSysOutLogger(ForestBuilder.class, SimpleLogger.DEBUG);
private TreeAlgo algorithm = TreeAlgo.BAG; // default bagging, TODO boosting
@@ -23,8 +23,11 @@
private ArrayList<DecisionTree> forest;
//private Learner trainer;
+ private DecisionTreeMerger merger;
+
public ForestBuilder() {
//this.trainer = _trainer;
+ merger = new DecisionTreeMerger();
}
public void build(Memory mem, Learner _trainer) {
@@ -53,22 +56,33 @@
else
bag = Util.bag_wo_rep(tree_capacity, N);
- InstanceList working_instances = class_instances.getInstances(bag);
+ InstanceList working_instances = class_instances.getInstances(bag);
+
+ if (slog.debug() != null)
+ slog.debug().log("\n"+"Training a tree"+"\n");
DecisionTree dt = _trainer.train_tree(working_instances);
-
+ if (slog.debug() != null)
+ slog.debug().log("\n"+"the end"+ "\n");
dt.setID(i);
forest.add(dt);
+ // the DecisionTreeMerger will visit the decision tree and add the paths that have not been seen yet to the list
+ merger.add(dt);
if (slog.stat() !=null)
slog.stat().stat(".");
}
+
+ //System.exit(0);
// TODO how to compute a best tree from the forest
- _trainer.setBestTree(forest.get(0));
+ DecisionTree best = merger.getBest();
+ if (best == null)
+ best = forest.get(0);
+ _trainer.setBestTree(best);// forest.get(0));
//this.c45 = dt;
}
-
+
public TreeAlgo getTreeAlgo() {
return algorithm; //TreeAlgo.BAG; // default
}
@@ -77,3 +91,5 @@
return forest;
}
}
+
+
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Learner.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Learner.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -15,6 +15,7 @@
public abstract class Learner {
protected static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(Learner.class, SimpleLogger.DEFAULT_LEVEL);
+ protected static SimpleLogger slog = LoggerFactory.getSysOutLogger(Learner.class, SimpleLogger.DEBUG);
public static enum DomainAlgo { CATEGORICAL, QUANTITATIVE }
public static DomainAlgo DEFAULT_DOMAIN = DomainAlgo.QUANTITATIVE;
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/AttributeChooser.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/AttributeChooser.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/AttributeChooser.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -3,9 +3,15 @@
import java.util.List;
import org.drools.learner.Domain;
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
public class AttributeChooser {
+ private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(AttributeChooser.class, SimpleLogger.WARN);
+ private static SimpleLogger slog = LoggerFactory.getSysOutLogger(AttributeChooser.class, SimpleLogger.WARN);
+
+
private Heuristic function;
public AttributeChooser(Heuristic _function) {
@@ -33,19 +39,27 @@
attribute_eval = function.getEval(attr_domain);//data_eval - function.info_attr(insts_by_target, attr_domain);
container.attribute_eval = attribute_eval;
container.domain = attr_domain;
+ if (slog.warn() != null)
+ slog.warn().log("CatAttribute: " + container.domain + " the gain: " + attribute_eval + " greatest "+ greatestEval+ "\n");
+
+
} else {
// the continuous domain
attribute_eval = function.getEval_cont(attr_domain);
-
-// attr_domain = function.getDomain();
-// sorted_instances = visitor.getSortedInstances();
-
container.attribute_eval = attribute_eval;
container.domain = function.getDomain();
container.sorted_data = function.getSortedInstances();
+ if (slog.warn() != null)
+ slog.warn().log("ContAttribute: " + container.domain + " the gain: " + attribute_eval + " greatest "+ greatestEval+ "\n");
+
+// attr_domain = function.getDomain();
+// sorted_instances = visitor.getSortedInstances();
+
+
}
-// flog.debug("Attribute: " + attr_domain + " the gain: " + gain);
+ if (slog.warn() != null)
+ slog.warn().log("Attribute: " + container.domain + " the gain: " + attribute_eval + " greatest "+ greatestEval+ "\n");
if (attribute_eval > greatestEval) {// TODO implement a comparator
greatestEval = attribute_eval;
best.domain = container.domain;
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/GainRatio.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/GainRatio.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/GainRatio.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -16,6 +16,8 @@
double info_gain = super.data_eval - Entropy.calc_info_attr(insts_by_attr);
double split_info = GainRatio.split_info(insts_by_attr);
+
+ System.err.println("(GainRatio) info_gain = "+ info_gain + "/"+ split_info);
return info_gain /split_info;
}
@@ -42,17 +44,22 @@
private static double split_info( CondClassDistribution instances_by_attr) {
//Collection<Object> attributeValues = instances_by_attr.getAttributes();
double data_size = instances_by_attr.getTotal();
- double sum = 0.0;
- if (data_size>0)
+ double sum = 1.0;
+ if (data_size>0) {
for (int attr_idx = 0; attr_idx < instances_by_attr.getNumCondClasses(); attr_idx++) {
Object attr_category = instances_by_attr.getCondClass(attr_idx);
double num_in_attr = instances_by_attr.getTotal_AttrCategory(attr_category);
- if (num_in_attr > 0) {
+ if (num_in_attr > 0.0) {
double prob = num_in_attr / data_size;
sum -= prob * Util.log2(prob);
}
}
+ } else {
+ System.err.println("????? data_size = "+ data_size);
+ System.exit(0);
+ }
+
//flog.debug("\n == "+sum);
return sum;
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -13,6 +13,9 @@
public ArrayList<Instance> sorted_data;
public InformationContainer() {
+ domain = null;
+ attribute_eval = 0.0;
+ sorted_data = null;
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -46,7 +46,7 @@
data_size += inst.getWeight();
Object target_key = inst.getAttrValue(tName);
- super.change(target_key, inst.getWeight()); // add one for vote for the target value : target_key
+ super.change(target_key, inst.getWeight()); // add inst.getWeight() vote for the target value of the instance : target_key
//super.change(attr_sum, inst.getWeight()); // ?????
this.addSupporter(target_key, inst);
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/LoggerFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/LoggerFactory.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/LoggerFactory.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -10,7 +10,8 @@
public class LoggerFactory {
- private static BufferedWriter buffer = new BufferedWriter(new StringWriter());
+ private static StringWriter str_writer = new StringWriter();
+ private static BufferedWriter buffer = new BufferedWriter(str_writer);
private static HashMap<Class<?>, SimpleLogger> fileLoggers = new HashMap<Class<?>, SimpleLogger>();
public static SimpleLogger getUniqueFileLogger(Class<?> klass, int level) {
@@ -43,12 +44,23 @@
int last_slash = file_sign.lastIndexOf('/');
String file_name = file_sign.substring(0, last_slash+1) + directory+file_sign.substring(last_slash)+"."+directory;
- System.out.println(file_name);
+ System.out.println("file "+ file_name+ " logged ");
+
PrintWriter writer;
try {
writer = new PrintWriter (new BufferedWriter (new FileWriter (file_name)));
- writer.write(buffer.toString());
+ buffer.flush();
+// System.out.println("LOG1: "+ str_writer.getBuffer());
+// System.out.println("LOG2: "+ str_writer.getBuffer().toString());
+
+ writer.write(str_writer.getBuffer().toString()); //str_buffer.getBuffer());
+ str_writer.close();
+ buffer.close();
+ writer.flush();
+ // Close the BufferedWriter object and the underlying
+ // StringWriter object.
+ writer.close();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/RulePrinter.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/RulePrinter.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/RulePrinter.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -12,19 +12,18 @@
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
-import java.util.Stack;
-import org.drools.learner.AttributeValueComparator;
import org.drools.learner.DecisionTree;
+import org.drools.learner.DecisionTreeVisitor;
import org.drools.learner.LeafNode;
-import org.drools.learner.TreeNode;
+import org.drools.learner.NodeValue;
+import org.drools.learner.Path;
public class RulePrinter {
private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(RulePrinter.class, SimpleLogger.WARN);
private static SimpleLogger slog = LoggerFactory.getSysOutLogger(RulePrinter.class, SimpleLogger.WARN);
-
public static Reader readRules(DecisionTree learned_dt) {
RulePrinter my_printer = new RulePrinter(); //bocuk.getNum_fact_trained()
@@ -42,30 +41,28 @@
private Class<?> rule_clazz;
- private Stack<NodeValue> nodes;
-
+ private DecisionTreeVisitor visitor;
private ArrayList<Rule> rules;
- //private ArrayList<String> ruleText;
-
private int bound_on_num_rules;
private double num_instances;
private HashMap<String, ArrayList<Field>> attrRelations;
+ private HashMap<String, Class<?>> importList; // TODO init
//private NumberComparator nComparator;
public RulePrinter() {
- /* most important */
- this.nodes = new Stack<NodeValue>();
+// /* most important */
+// this.nodes = new Stack<NodeValue>();
+ this.visitor = new DecisionTreeVisitor();
+
this.rules = new ArrayList<Rule>();
//ruleText = new ArrayList<String>();
this.bound_on_num_rules = -1;
this.num_instances = -1.0d;
-
- //this.nComparator = new NumberComparator();
}
public Class<?> getRuleClass() {
@@ -91,83 +88,60 @@
public void print(DecisionTree dt, boolean sort) {//, PrintStream object
this.rule_clazz = dt.getObjClass();
this.attrRelations = dt.getAttrRelationMap();
+ this.importList = new HashMap<String, Class<?>>();
this.num_instances = dt.getRoot().getNumMatch();
- dfs(dt.getRoot());
- if (sort)
- Collections.sort(rules, Rule.getRankComparator());
- }
-
- private void dfs(TreeNode my_node) {
- //System.out.println("How many guys there of "+my_node.getDomain().getName() +" : "+my_node.getDomain().getValues().size());
+ visitor.visit(dt);
- NodeValue node_value = new NodeValue(my_node);
- nodes.push(node_value);
-
- if (my_node instanceof LeafNode) {
- node_value.setNodeValue(((LeafNode) my_node).getCategory()); //getValue(null));
- //ruleText.add(print(nodes));
- //rule_list.add(spit(nodes));
- // what if more than one condition (more than one leafNode)
-
- Rule newRule = spitRule(nodes);
+ for (Path p: visitor.getPathList()) {
+ Rule newRule = createRule(p);
newRule.setId(rules.size());
rules.add(newRule);
- return;
}
- for (Object attributeValue : my_node.getChildrenKeys()) {
- //System.out.println("Domain: "+ my_node.getDomain().getName() + " the value:"+ attributeValue);
- node_value.setNodeValue(attributeValue);
- TreeNode child = my_node.getChild(attributeValue);
- dfs(child);
- nodes.pop();
- }
- return;
+ if (sort)
+ Collections.sort(rules, Rule.getRankComparator());
}
-
- private Rule spitRule(Stack<NodeValue> nodes) {
+
+ private Rule createRule(Path p) {
//, Stack<NodeValue> leaves // if more than one leaf
- Rule newRule = new Rule(this.getRuleClass(), nodes.size());// (nodes, leaves) //if more than one leaf
- Iterator<NodeValue> it = nodes.iterator();
+ Rule newRule = new Rule(this.getRuleClass());// (nodes, leaves) //if more than one leaf
+
+ Iterator<NodeValue> it = p.getConditionIterator();
while (it.hasNext()) {
-
NodeValue current = it.next();
if (slog.error() != null)
slog.error().log("NodeValue " +current + "\n");
if (slog.error() != null)
slog.error().log("attrRelations [" +attrRelations.size() + "]\n");
- if (it.hasNext()) {
- ArrayList<Field> nodeRelations = attrRelations.get(current.getFReference());
-
- if (nodeRelations == null || nodeRelations.isEmpty()) {
- // this a direct child add
- newRule.addConditionToMain(current);
-
- } else {
-
- for (Field f:nodeRelations) {
- // i need the class that the field belongs to boooook
- String referenceOfCondition = Util.getDecReference(f);
- if (slog.error() != null)
- slog.error().log("[" +referenceOfCondition + "],");
- }
+// if (it.hasNext()) {
+ ArrayList<Field> nodeRelations = attrRelations.get(current.getFReference());
+
+ if (nodeRelations == null || nodeRelations.isEmpty()) {
+ // this a direct child add
+ newRule.addConditionToMain(current);
+ } else {
+ for (Field f:nodeRelations) {
+ // i need the class that the field belongs to boooook
+ String referenceOfCondition = Util.getDecReference(f);
if (slog.error() != null)
- slog.error().log("\n");
- newRule.processNodeValue(current, nodeRelations, 0, 1); //int condition_or_action = condition = 1
+ slog.error().log("[" +referenceOfCondition + "],");
}
- } else {
-
- ArrayList<Field> nodeRelations = attrRelations.get(current.getFReference());
- if (nodeRelations == null || nodeRelations.isEmpty()) {
- // this a direct child add to reference to the main guy
- newRule.addActionToMain(current);
- } else {
- newRule.processNodeValue(current, nodeRelations, 0, 2); //int condition_or_action = action = 2
- }
+ if (slog.error() != null)
+ slog.error().log("\n");
+ newRule.processNodeValue(current, nodeRelations, 0, 1); //int condition_or_action = condition = 1
}
+// } else { }
}
+ NodeValue action_node = p.getAction();
+ ArrayList<Field> nodeRelations = attrRelations.get(action_node.getFReference());
+ if (nodeRelations == null || nodeRelations.isEmpty()) {
+ // this a direct child add to reference to the main guy
+ newRule.addActionToMain(action_node);
+ } else {
+ newRule.processNodeValue(action_node, nodeRelations, 0, 2); //int condition_or_action = action = 2
+ }
return newRule;
}
@@ -209,26 +183,19 @@
}
public String write2string(){//String packageName) {
- StringBuffer outputBuffer = new StringBuffer();
+ StringBuffer introBuffer = new StringBuffer();
+ StringBuffer bodyBuffer = new StringBuffer();
String packageName = this.getRuleClass().getPackage().getName();
//log.debug("Package name: "+ packageName);
if (packageName != null)
- outputBuffer.append("package " + packageName +";\n\n");
-
+ introBuffer.append("package " + packageName +";\n\n");
else {
//TODO throw exception
//flog.error("RulePrinter write2string packageName="+packageName);
}
-
-// flog.debug(new Object() {
-// public String toString() {
-// String out = "//Num of rules " +rules.size()+"\n //this.getBoundOnNumRules() "+ getBoundOnNumRules();
-// return out;
-// }
-// });
int total_num_facts=0;
int i = 0, active_i = 0;
@@ -238,33 +205,51 @@
if (Util.ONLY_ACTIVE_RULES) {
if (rule.getRank() >= 0) {
active_i++;
-// if (Util.DEBUG_RULE_PRINTER) {
-// System.out.println("//Active rules " +i + " write to drl \n"+ rule +"\n");
-// }
- outputBuffer.append(rule.toString());
- outputBuffer.append("\n");
+ bodyBuffer.append(rule.toString());
+ bodyBuffer.append("\n");
+
+ introBuffer.append(getImports(rule.getDeclarationIt()));
}
} else {
if (rule.getRank() >= 0) {
active_i++;
}
-// if (Util.DEBUG_RULE_PRINTER) {
-// System.out.println("//rule " +i + " write to drl \n"+ rule +"\n");
-// }
- outputBuffer.append(rule.toString());
- outputBuffer.append("\n");
+ bodyBuffer.append(rule.toString());
+ bodyBuffer.append("\n");
+ introBuffer.append(getImports(rule.getDeclarationIt()));
}
total_num_facts += rule.getNumClassifiedInstances();
if (this.getBoundOnNumRules()>0 && i >= this.getBoundOnNumRules())
break;
}
- outputBuffer.append("//THE END: Total number of facts correctly classified= "+ total_num_facts + " over "+ this.getNumInstances());
- outputBuffer.append("\n//with " + active_i + " number of rules over "+i+" total number of rules ");
- outputBuffer.append("\n"); // EOF
+ bodyBuffer.append("//THE END: Total number of facts correctly classified= "+ total_num_facts + " over "+ this.getNumInstances());
+ bodyBuffer.append("\n//with " + active_i + " number of rules over "+i+" total number of rules ");
+ bodyBuffer.append("\n"); // EOF
+
+ StringBuffer outputBuffer = new StringBuffer();
+ outputBuffer.append(introBuffer);
+ outputBuffer.append("\n");
+ outputBuffer.append(bodyBuffer);
return outputBuffer.toString();
}
+
+ private StringBuffer getImports(Iterator<Declaration> declarationIt) {
+ StringBuffer importBuffer = new StringBuffer();
+ while (declarationIt.hasNext()) {
+ Declaration dec = declarationIt.next();
+ String name = dec.getDeclaringTypeName();
+ if (!importList.containsKey(name)) {
+ importList.put(name, dec.getDeclaringType());
+ //String import_name = dec.getDeclaringType().getName().replaceAll("$", ".");
+ String import_name = dec.getDeclaringType().getName().replace('$', '.');
+ importBuffer.append("import "+ import_name);
+ importBuffer.append("\n");
+ }
+ }
+ return importBuffer;
+ }
}
class Rule {
@@ -285,7 +270,7 @@
// private String referenceToMain = main_obj.getName()+"0";
private int main_obj_id = 0;
- Rule(Class<?> obj, int numCond) {
+ Rule(Class<?> obj) {
num_declarations = 0;
rule_decs = new ArrayList<Declaration>(1); //new ArrayList<Declaration>(1);
declarationMap = new HashMap<String, Integer>(1);
@@ -297,6 +282,10 @@
actions = new ArrayList<AttrReference>(1);
}
+ public Iterator<Declaration> getDeclarationIt() {
+ return rule_decs.iterator();
+ }
+
public void addConditionToMain(NodeValue current) {
rule_decs.get(main_obj_id).addCondition(current);
}
@@ -414,8 +403,6 @@
this.setNumClassifiedInstances(((LeafNode)current.getNode()).getNumClassification());
}
-
-
public String getObjectClassName() {
return main_obj.getSimpleName();
}
@@ -470,7 +457,7 @@
for (int dec_i =rule_decs.size()-1; dec_i>=0; dec_i--) {
Declaration d = rule_decs.get(dec_i);
String obj_ref = d.getSymbol(); //"$"+this.getObjectClassName().substring(0, 1).toLowerCase();
- sb_out.append("\n\t\t "+obj_ref+" : "+d.getDeclaringFTypeCanonicalName()+"("+ "");
+ sb_out.append("\n\t\t "+obj_ref+" : "+d.getDeclaringTypeName()+"("+ "");
Iterator<NodeValue> dec_it = d.getConditionIt();
while (dec_it.hasNext()) {
NodeValue cond = dec_it.next();
@@ -483,9 +470,9 @@
sb_out.append(ref.toString() + ", ");
}
sb_out.delete(sb_out.length()-2, sb_out.length()-1);
- sb_out.append(")\n");
+ sb_out.append(")");
}
-
+ sb_out.append("\n");
StringBuffer sb_action = new StringBuffer("");
StringBuffer sb_field = new StringBuffer("");
StringBuffer sb_expected_value = new StringBuffer("");
@@ -507,63 +494,7 @@
return sb_out.toString();
}
-
-
-// public String toString_() {
-// /*
-// rule "Good Bye"
-// dialect "java"
-// when
-// $m:Message( status == Message.GOODBYE)
-// then
-// System.out.println( "[getLabel()] Expected value (" + $c.getLabel() + "), Classified as (False)");
-// end
-// */
-// //"rule \"#"+getId()+" "+decision+" rank:"+rank+"\" \n";
-// StringBuffer sb_out = new StringBuffer("");
-// String obj_ref = "$"+this.getObjectClassName().substring(0, 1).toLowerCase();
-//
-// sb_out.append("\t when");
-// sb_out.append("\n\t\t "+obj_ref+":"+this.getObjectClassName() +"("+ "");
-//// for (NodeValue cond: conditions) {
-//// sb_out.append(cond.toString() + ", ");
-//// }
-//
-// StringBuffer sb_action = new StringBuffer("");
-// StringBuffer sb_field = new StringBuffer("");
-// StringBuffer sb_expected_value = new StringBuffer("");
-// for (NodeValue act: actions) {
-// // if the query is on a field then i have to get its value during in the rule 'cause it might be private
-// if (!act.getNode().getDomain().isArtificial())
-// sb_out.append(obj_ref+ "_"+act.getFName() + " : "+act.getFName()+", ");
-//
-// sb_action.append(act.getNodeValue() + " , ");
-// if (!act.getNode().getDomain().isArtificial())
-// sb_field.append(act.getFName() + "");
-// else
-// sb_field.append(act.getFName() + "()");
-//
-//
-// if (!act.getNode().getDomain().isArtificial())
-// sb_expected_value.append(obj_ref+ "_"+act.getFName());//reading the value by the reference of $o_fieldname
-// else
-// sb_expected_value.append(obj_ref+ "."+act.getFName() + "()");// reading the value from the object $o.function()
-//
-// }
-// sb_action.delete(sb_action.length()-3, sb_action.length()-1);
-// sb_out.delete(sb_out.length()-2, sb_out.length()-1);
-// sb_out.append(")\n");
-// sb_out.append("\t then ");
-// sb_out.append("\n\t\t System.out.println(\"["+sb_field.toString()+ "] Expected value (\" + "+ sb_expected_value.toString()+ " + \"), Classified as ("+sb_action.toString()+")\");\n");
-// if (getRank() <0)
-// sb_out.append("\n\t\t System.out.println(\"But no matching fact found = DOES not fire on\");\n");
-//
-// sb_out.insert(0, "rule \"#"+getId()+" "+sb_field.toString()+ "= "+sb_action.toString()+" classifying "+this.getNumClassifiedInstances()+" num of facts with rank:"+getRank() +"\" \n");
-// sb_out.append("end\n");
-//
-// return sb_out.toString();
-// }
-
+
public static Comparator<Rule> getRankComparator() {
return new RuleComparator();
}
@@ -602,14 +533,20 @@
return dec_ref;
}
- public String getDeclaringFTypeCanonicalName() {
+ public Class<?> getDeclaringType() {
+ return declared_obj;
+ }
+ public String getDeclaringTypeName() {
return declared_obj.getSimpleName();
}
public void addCondition(NodeValue current) {
- NodeValue nv = new NodeValue(current.getNode());
- nv.setNodeValue(current.getNodeValue());
- conditions.add(nv);
+ /* TODO check do u need to copy the node value
+ NodeValue nv = new NodeValue(current.getNode());
+ nv.setNodeValue(current.getNodeValue());
+ conditions.add(nv);
+ /**/
+ conditions.add(current);
}
public void addActionReference(AttrReference aRef) {//NodeValue current) {
@@ -681,14 +618,17 @@
_fName = Util.getFieldName(_fName);
}
fName = _fName;
+ /* TODO is the copying really necessary?
real_value = new NodeValue(v.getNode());
real_value.setNodeValue(v.getNodeValue());
+ */
+ real_value = v;
}
public Object getVariableName() {
return "$target_label";
}
public Object getValue() {
- return real_value.getNodeValue();
+ return real_value.getValue();
}
public String getFName() {
return fName;
@@ -699,151 +639,3 @@
}
}
-
-class NodeValue { //implements RuleNode {
-
- //private static final Logger flog = LoggerFactory.getFileLogger(NodeValue.class, LogLevel.ERROR, Util.log_file);
-
- private TreeNode node;
- private Object nodeValue; // should it be Attribute???
-
- public NodeValue(TreeNode n) {
- this.node = n;
- }
- public String getFReference() {
- return node.getDomain().getFReferenceName();
- }
-
- public String getFName() {
-// String full_name = node.getDomain().getFName();
-// String fname = full_name.substring(full_name.lastIndexOf('@')+1, full_name.length());
- return node.getDomain().getFName() ;
- }
-
- public Object getNodeValue() {
- return nodeValue;
- }
- public void setNodeValue(Object category) {
- this.nodeValue = category;
-
- }
-
- public TreeNode getNode() {
- return node;
- }
-
- public String toString() {
-
- String fName = this.getFName();//object class name
- Class<?> node_obj = node.getDomain().getObjKlass();
-
- String value;
- if (node.getDomain().getFType() == String.class)
- value = "\""+nodeValue+ "\"";
- else
- value = nodeValue + "";
-
- if (node.getDomain().isCategorical())
- return fName + " == "+ value;
- else {
-
- int size = node.getDomain().getCategoryCount()-1;
- //System.out.println("How many guys there of "+node.getDomain().getName() +" and the value "+nodeValue+" : "+size);
-
- int idx = size;
- if (nodeValue instanceof Number) {
- for (; idx>=0; idx--) {
- Object categoryValue = node.getDomain().getCategory(idx);
- if (nodeValue instanceof Comparable && categoryValue instanceof Comparable) {
- // TODO ask this to daniel???
-// if (Util.DEBUG_RULE_PRINTER) {
-// System.out.println("NodeValue:"+ nodeValue+ " c-"+nodeValue.getClass() +" & category:"+ categoryValue+ " c-"+categoryValue.getClass());
-// }
- if ( AttributeValueComparator.instance.compare(nodeValue, categoryValue) == 0 ) {
- break;
- }
- } else {
- System.out.println("Fuck not comparable NodeValue:"+ nodeValue+ " c-"+nodeValue.getClass() +" & category:"+ categoryValue+ " c-"+categoryValue.getClass());
- System.exit(0);
- }
-
- }
- } else {
- /* TODO implement the String setting */
- System.out.println("Fuck not number:"+ nodeValue+ " c-"+nodeValue.getClass());
- System.exit(0);
-
- }
-
- if (idx == 0)
- return fName + " <= "+ value;
- else if (idx == size)
- // if the category is the last one that the rule is domain.name > category(last-1)
- return fName+ " > "+ node.getDomain().getCategory(size-1);
- else {
- //return node.getDomain().getCategory(idx) + " < " + fName+ " <= "+ node.getDomain().getCategory(idx+1);
- // Why drools does not support category(idx) < domain.name <= category(idx+1)
- //flog.debug("value "+ value + "=====?????"+ node.getDomain().getCategory(idx+1));
-
- return fName+ " <= "+ value; // node.getDomain().getCategory(idx+1);
- }
- }
- }
-
- public String toString_() {
-
- String fName = this.getFName();//node.getDomain().getFName();
- String value;
- if (node.getDomain().getFType() == String.class)
- value = "\""+nodeValue+ "\"";
- else
- value = nodeValue + "";
-
- if (node.getDomain().isCategorical())
- return fName + " == "+ value;
- else {
-
- int size = node.getDomain().getCategoryCount()-1;
- //System.out.println("How many guys there of "+node.getDomain().getName() +" and the value "+nodeValue+" : "+size);
-
- int idx = size;
- if (nodeValue instanceof Number) {
- for (; idx>=0; idx--) {
- Object categoryValue = node.getDomain().getCategory(idx);
- if (nodeValue instanceof Comparable && categoryValue instanceof Comparable) {
- // TODO ask this to daniel???
-// if (Util.DEBUG_RULE_PRINTER) {
-// System.out.println("NodeValue:"+ nodeValue+ " c-"+nodeValue.getClass() +" & category:"+ categoryValue+ " c-"+categoryValue.getClass());
-// }
- if ( AttributeValueComparator.instance.compare(nodeValue, categoryValue) == 0 ) {
- break;
- }
- } else {
- System.out.println("Fuck not comparable NodeValue:"+ nodeValue+ " c-"+nodeValue.getClass() +" & category:"+ categoryValue+ " c-"+categoryValue.getClass());
- System.exit(0);
- }
-
- }
- } else {
- /* TODO implement the String setting */
- System.out.println("Fuck not number:"+ nodeValue+ " c-"+nodeValue.getClass());
- System.exit(0);
-
- }
-
- if (idx == 0)
- return fName + " <= "+ value;
- else if (idx == size)
- // if the category is the last one that the rule is domain.name > category(last-1)
- return fName+ " > "+ node.getDomain().getCategory(size-1);
- else {
- //return node.getDomain().getCategory(idx) + " < " + fName+ " <= "+ node.getDomain().getCategory(idx+1);
- // Why drools does not support category(idx) < domain.name <= category(idx+1)
- //flog.debug("value "+ value + "=====?????"+ node.getDomain().getCategory(idx+1));
-
- return fName+ " <= "+ value; // node.getDomain().getCategory(idx+1);
- }
- }
- }
-
-}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -37,7 +37,7 @@
}
// instantiate a learner for a specific object class and pass session to train
- DecisionTree decision_tree; int ALGO = 111;
+ DecisionTree decision_tree; int ALGO = 322;
/*
* Single 1xx, Bag 2xx, Boost 3xx
* ID3 x1x, C45 x2x
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -40,7 +40,7 @@
session.insert(r);
}
- DecisionTree decision_tree; int ALGO = 111;
+ DecisionTree decision_tree; int ALGO = 322;
/*
* Single 1xx, Bag 2xx, Boost 3xx
* ID3 x1x, C45 x2x
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java 2008-06-26 15:18:52 UTC (rev 20814)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/TriangleExample.java 2008-06-26 16:17:29 UTC (rev 20815)
@@ -36,7 +36,7 @@
}
// instantiate a learner for a specific object class and pass session to train
- DecisionTree decision_tree; int ALGO = 321;
+ DecisionTree decision_tree; int ALGO = 322;
/*
* Single 1xx, Bag 2xx, Boost 3xx
* ID3 x1x, C45 x2x
More information about the jboss-svn-commits mailing list