[jboss-svn-commits] JBL Code SVN: r20385 - in labs/jbossrules/contrib/machinelearning/4.0.x: drools-core and 6 other directories.
jboss-svn-commits at lists.jboss.org
jboss-svn-commits at lists.jboss.org
Mon Jun 9 14:37:43 EDT 2008
Author: gizil
Date: 2008-06-09 14:37:43 -0400 (Mon, 09 Jun 2008)
New Revision: 20385
Added:
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
Removed:
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfC45Example.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
Modified:
labs/jbossrules/contrib/machinelearning/4.0.x/drools-compiler/src/main/java/org/drools/compiler/PackageBuilder.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/pom.xml
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/audit/WorkingMemoryFileLogger.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Domain.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Instance.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/InstanceList.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Memory.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/QuantitativeDomain.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Schema.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Stats.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeBuilder.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Learner.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Tester.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Categorizer.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/ClassDistribution.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CondClassDistribution.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Entropy.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/FeatureNotSupported.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/NumberComparator.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/ObjectFactory.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/RulePrinter.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/Util.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Car.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfC45ExampleFromDrl.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Poker.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java
labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/RestaurantExample.java
Log:
backup before adding the Heuristic class
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-compiler/src/main/java/org/drools/compiler/PackageBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-compiler/src/main/java/org/drools/compiler/PackageBuilder.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-compiler/src/main/java/org/drools/compiler/PackageBuilder.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -42,9 +42,10 @@
import org.drools.lang.descr.PackageDescr;
import org.drools.lang.descr.QueryDescr;
import org.drools.lang.descr.RuleDescr;
-import org.drools.learner.builder.ID3Learner;
-import org.drools.learner.builder.Learner;
+import org.drools.learner.DecisionTree;
+import org.drools.learner.tools.LoggerFactory;
import org.drools.learner.tools.RulePrinter;
+import org.drools.learner.tools.Util;
import org.drools.rule.Package;
import org.drools.rule.Rule;
import org.drools.rule.builder.RuleBuildContext;
@@ -211,23 +212,25 @@
}
/**
- * Load a learner to load its rule package from DRL source AND/OR AST.
+ * Load the decision tree and its rules.
*
* @param learner
* @throws DroolsParserException
* @throws IOException
*/
- public void addPackageFromLearner(final Learner learner) throws DroolsParserException, IOException {
- Reader reader = RulePrinter.readRules(learner);
+ public void addPackageFromTree(final DecisionTree dt) throws DroolsParserException, IOException {
+ Reader reader = RulePrinter.readRules(dt);
/* final DrlParser parser = new DrlParser();
final PackageDescr pkg = parser.parse( reader );
this.results.addAll( parser.getErrors() );
addPackage( pkg );
*/
-
+ // save the logger
+ LoggerFactory.dump_buffer(Util.DRL_DIRECTORY+dt.getSignature(), "log");
this.addPackageFromDrl(reader);
}
+
/**
* Add a ruleflow (.rt) asset to this package.
*/
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/pom.xml
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/pom.xml 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/pom.xml 2008-06-09 18:37:43 UTC (rev 20385)
@@ -1,5 +1,5 @@
<?xml version="1.0" encoding="UTF-8"?>
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
+<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
<modelVersion>4.0.0</modelVersion>
<parent>
@@ -32,6 +32,11 @@
<artifactId>mvel</artifactId>
</dependency>
+ <dependency>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ <version>1.2.9</version>
+ </dependency>
</dependencies>
<build>
@@ -58,4 +63,4 @@
</build>
-</project>
+</project>
\ No newline at end of file
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/audit/WorkingMemoryFileLogger.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/audit/WorkingMemoryFileLogger.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/audit/WorkingMemoryFileLogger.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -24,12 +24,6 @@
import org.drools.WorkingMemoryEventManager;
import org.drools.audit.event.LogEvent;
-import org.drools.event.AfterFunctionRemovedEvent;
-import org.drools.event.AfterRuleBaseLockedEvent;
-import org.drools.event.AfterRuleBaseUnlockedEvent;
-import org.drools.event.BeforeFunctionRemovedEvent;
-import org.drools.event.BeforeRuleBaseLockedEvent;
-import org.drools.event.BeforeRuleBaseUnlockedEvent;
import com.thoughtworks.xstream.XStream;
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/DecisionTree.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -3,11 +3,12 @@
import java.util.ArrayList;
import java.util.Collections;
-import org.drools.learner.tools.Util;
-
public class DecisionTree {
+ //private static final Logger log = LoggerFactory.getSysOutLogger(LogLevel.ERROR);
+ //private static final Logger flog = LoggerFactory.getFileLogger(DecisionTree.class, LogLevel.ERROR, Util.log_file);
+
private Class<?> obj_clazz;
/* the target attribute */
@@ -21,21 +22,20 @@
// The id of the tree in the forest
private int id;
+ private String execution_signature;
public long FACTS_READ = 0;
public DecisionTree(Schema inst_schema, String _target) {
this.obj_clazz = inst_schema.getObjectClass();
- if (Util.DEBUG_DECISION_TREE) {
- System.out.println("The target attribute: "+ _target);
- }
+ //flog.debug("The target attribute: "+ _target);
+
this.target = inst_schema.getAttrDomain(_target);
this.attrsToClassify = new ArrayList<Domain>(inst_schema.getAttrNames().size()-1);
for (String attr_name : inst_schema.getAttrNames()) {
if (!attr_name.equals(_target)) {
- if (Util.DEBUG_DECISION_TREE) {
- System.out.println("Adding the attribute: "+ attr_name);
- }
+ //flog.debug("Adding the attribute: "+ attr_name);
+
this.attrsToClassify.add(inst_schema.getAttrDomain(attr_name));
}
}
@@ -84,12 +84,18 @@
return this.getRoot().voteFor(i);
}
+ public void setSignature(String executionSignature) {
+ execution_signature = executionSignature;
+ }
+
+ public String getSignature() {
+ return execution_signature;
+ }
+
@Override
public String toString() {
String out = "Facts scanned " + FACTS_READ + "\n";
return out + root.toString();
- }
+ }
-
-
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Domain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Domain.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Domain.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -4,7 +4,7 @@
import java.util.Collections;
public class Domain {
- private boolean categorical, fixed;
+ private boolean categorical, fixed, artificial;
private String fName;
private Class<?> fType;// not sure if necessary
protected ArrayList<Object> fCategories;
@@ -14,6 +14,7 @@
this.fType = _type;
this.categorical = true; // BY DEFAULT, it is categorical
+ this.artificial = false; // BY DEFAULT, it is a real field, if it is artificial it means there is no field exist but there is method which computes the value
this.fCategories = new ArrayList<Object>(2);
@@ -36,6 +37,10 @@
public String getFName() {
return this.fName;
}
+
+ protected void setFName(String name) {
+ this.fName = name;
+ }
public void setFixed(boolean _fixed) {
this.fixed = _fixed;
}
@@ -134,11 +139,20 @@
}
public String toString() {
- String out = fName + "";
+ StringBuffer sb = new StringBuffer(fName + "");
// for (Object v: fValues) {
-// out += "-" + v;
+// sb.append("-" + v);
// }
- return out;
+ return sb.toString();
+
}
+
+ // if the field is a artificial field
+ public void setArtificial(boolean b) {
+ this.artificial = b;
+ }
+ public boolean isArtificial() {
+ return this.artificial;
+ }
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Instance.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Instance.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Instance.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -5,12 +5,19 @@
public class Instance {
private HashMap<String, Attribute> attributes;
+ private double weight = 1.0;
public Instance() {
this.attributes = new HashMap<String, Attribute>(); // TODO should i set a size, HOW?
}
+ public void setWeight(double w) {
+ weight = w;
+ }
+ public double getWeight() {
+ return weight;
+ }
public void setAttr(String _name, Object _value) {
Attribute f_attr = new Attribute();
@@ -31,12 +38,11 @@
}
public String toString() {
- String out = this.hashCode() + " ";
- for (String key: attributes.keySet())
- {
- out += key +"="+attributes.get(key)+", ";
+ StringBuffer sb = new StringBuffer(this.hashCode() + " ");
+ for (String key: attributes.keySet()) {
+ sb.append(key +"="+attributes.get(key)+", ");
}
- return out;
+ return sb.toString();
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/InstanceList.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/InstanceList.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/InstanceList.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -7,12 +7,14 @@
import org.drools.WorkingMemory;
import org.drools.base.ClassFieldExtractor;
import org.drools.common.InternalWorkingMemory;
+import org.drools.spi.Extractor;
public class InstanceList {
private Schema schema;
private ArrayList<Instance> instances;
+ private ArrayList<Double> weights;
public InstanceList(Schema _schema) {
this.schema = _schema;
@@ -30,8 +32,8 @@
for (String f_name : schema.getAttrNames()) {
Domain f_domain = schema.getAttrDomain(f_name);
- ClassFieldExtractor f_extractor = schema.getAttrExtractor(f_name);
-
+ //ClassFieldExtractor f_extractor = schema.getAttrExtractor(f_name);
+ Extractor f_extractor = schema.getAttrExtractor(f_name);//Label
/* from WorkingMemoryLogger, private String extractDeclarations(final Activation activation, final WorkingMemory workingMemory) {
* you can cast the WorkingMemory
* final Object value = declaration.getValue( (InternalWorkingMemory) workingMemory, handleImpl.getObject() );
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/LeafNode.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -5,7 +5,7 @@
public class LeafNode extends TreeNode {
private Object targetCategory;
- private int num_intances_classified;
+ private double num_intances_classified;
public LeafNode(Domain targetDomain, Object value) {
super(targetDomain);
@@ -21,11 +21,11 @@
return targetCategory;
}
- public void setNumClassification(int size) {
+ public void setNumClassification(double size) {
this.num_intances_classified= size;
}
- public int getNumClassification() {
+ public double getNumClassification() {
return this.num_intances_classified;
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Memory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Memory.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Memory.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -4,13 +4,14 @@
import java.util.Iterator;
import org.drools.WorkingMemory;
+import org.drools.learner.builder.Learner.DomainAlgo;
import org.drools.learner.tools.FeatureNotSupported;
public class Memory {
- // TODO pass a list of classes, and get all th eobject from that class
- public static Memory createFromWorkingMemory(WorkingMemory _session, Class<?> clazz, int domain_type) throws FeatureNotSupported {
+ // TODO pass a list of classes, and get all the object from that class
+ public static Memory createFromWorkingMemory(WorkingMemory _session, Class<?> clazz, DomainAlgo domain_type) throws FeatureNotSupported {
// if mem == null
Memory mem = new Memory();
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/QuantitativeDomain.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/QuantitativeDomain.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/QuantitativeDomain.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -40,6 +40,9 @@
}
+ public void setFName(String fname) {
+ super.setFName(fname);
+ }
public boolean isCategorical() {
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Schema.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Schema.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Schema.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -2,6 +2,7 @@
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
+import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
@@ -10,35 +11,38 @@
import org.drools.base.ClassFieldExtractor;
import org.drools.base.ClassFieldExtractorCache;
+import org.drools.learner.builder.Learner.DomainAlgo;
+import org.drools.learner.tools.ClassAnnotation;
import org.drools.learner.tools.FeatureNotSupported;
import org.drools.learner.tools.FieldAnnotation;
+import org.drools.learner.tools.PseudoFieldExtractor;
import org.drools.learner.tools.Util;
+import org.drools.spi.Extractor;
/*
* A description of a data set's attributes and their properties.
*/
public class Schema {
- public static Schema createFromClass(Class<?> clazz, int domain_type) throws FeatureNotSupported {
+ public static Schema createFromClass(Class<?> clazz, DomainAlgo domain_type) throws FeatureNotSupported {
Schema schema = new Schema(clazz);
ClassFieldExtractorCache cache = ClassFieldExtractorCache.getInstance();
+ /*
+ ArrayList<Field> element_fields = new ArrayList<Field>();
+ Util.getAllFields(clazz, element_fields);
+ */
ArrayList<Field> element_fields = new ArrayList<Field>();
- Util.getAllFields(clazz, element_fields);
- //clazz.getDeclaredFields(); //clazz.getFields();
+ /* Apperantly the getMethod function recurse on the superclasses
+ * i dont need to recurse myself
+ ArrayList<Class<?>> element_classes = new ArrayList<Class<?>>();
+ Util.getAllFields(clazz, element_fields, element_classes);
+ */
+ Util.getAllFields(clazz, element_fields); //clazz.getDeclaredFields(); //clazz.getFields();
for (Field f: element_fields) {
String f_name = f.getName();
ClassFieldExtractor f_extractor = cache.getExtractor( clazz, f_name, clazz.getClassLoader() );
schema.extractorMap.put(f_name, f_extractor);
- //f_extractor.
-
- int f_type = 0;
- if (f.getType().isPrimitive() || f.getType() == String.class) {
-
- f_type = 1;
- } else {
- f_type = 2;
- }
Annotation[] annotations = f.getAnnotations();
FieldAnnotation spec = null;
@@ -56,11 +60,11 @@
}
boolean skip = false;
switch (domain_type) {
- case Util.ID3: //ID3
+ case ID3: //ID3
if (spec.ignore() || !spec.discrete())
skip = true;
break;
- case Util.C45:
+ case C45:
if (spec.ignore())
skip = true;
break;
@@ -93,6 +97,57 @@
*/
schema.domainMap.put(f_name, fieldDomain);
}
+ /* Apperantly the getMethod function recurse on the superclasses
+ * i dont need to recurse myself
+ for (Class<?> c: element_classes) {
+ Annotation[] annotations = c.getAnnotations();
+ */
+ Annotation[] annotations = clazz.getAnnotations(); // it should get the inherited annotations
+ ClassAnnotation lab = null;
+ for (Annotation a : annotations) {
+ if (a instanceof ClassAnnotation) {
+ lab= (ClassAnnotation)a; // here it is !!!
+ break;
+ }
+ }
+
+ if (lab != null && lab.label_element() != "") {
+ // the targetting label is set, put the function that gets that value somewhere
+
+ try {
+ /* Apperantly the getMethod function recurse on the superclasses
+ * i dont need to recurse myself
+ Method m =c.getDeclaredMethod(lab.label_element(), null);
+ */
+ Method m = clazz.getMethod(lab.label_element(), null);
+
+ Domain fieldDomain = new Domain(lab.label_element(), m.getReturnType());
+ fieldDomain.setArtificial(true);
+ if (m.getReturnType() == Boolean.TYPE || m.getReturnType() == Boolean.class) { /* set discrete*/
+// fieldDomain.setCategorical(true); // BY DEFAULT it is categorical
+ fieldDomain.addCategory(Boolean.TRUE);
+ fieldDomain.addCategory(Boolean.FALSE);
+ fieldDomain.setFixed(true);
+ }
+ //else if (m.getReturnType() == String.class) { /* BY DEFAULT it is categorical*/}
+ schema.domainMap.put(lab.label_element(), fieldDomain);
+
+ Extractor m_extractor = new PseudoFieldExtractor(clazz, m);
+ //cache.getExtractor( clazz, lab.label_element(), clazz.getClassLoader() );
+ schema.extractorMap.put(lab.label_element(), m_extractor);
+ schema.clearTargets();
+ schema.addTarget(lab.label_element());
+ //break; // if the ClassAnnotation is found then stop
+
+ } catch (SecurityException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } catch (NoSuchMethodException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ }
+ //}
return schema;
}
@@ -101,7 +156,7 @@
private Class<?> klass;
// key: field name
- private Hashtable<String, ClassFieldExtractor> extractorMap;
+ private Hashtable<String, Extractor> extractorMap;
// key: field name
private Hashtable<String, Domain> domainMap;
@@ -110,7 +165,7 @@
public Schema(Class<?> _klass) {
this.klass = _klass;
- this.extractorMap = new Hashtable<String, ClassFieldExtractor>();
+ this.extractorMap = new Hashtable<String, Extractor>();
this.domainMap = new Hashtable<String, Domain>();
this.targets = new HashSet<String>();
}
@@ -123,6 +178,10 @@
return targets.add(_target);
}
+ public void clearTargets() {
+ targets.clear();
+ }
+
public Set<String> getAttrNames() {
return this.domainMap.keySet();
}
@@ -131,7 +190,7 @@
return this.domainMap.get(attr_name);
}
- public ClassFieldExtractor getAttrExtractor(String attr_name) {
+ public Extractor getAttrExtractor(String attr_name) {
return this.extractorMap.get(attr_name);
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Stats.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Stats.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/Stats.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -1,20 +1,16 @@
package org.drools.learner;
-import java.io.BufferedWriter;
-import java.io.DataOutputStream;
import java.io.FileNotFoundException;
-import java.io.FileOutputStream;
import java.io.FileWriter;
-import java.io.OutputStream;
-import java.io.OutputStreamWriter;
+import java.io.IOException;
import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.io.Writer;
import java.util.ArrayList;
-import org.drools.learner.tools.Util;
-
public class Stats {
- public static int INCORRECT = 0, CORRECT = 1, UNKNOWN = 2;
+ public static final int INCORRECT = 0, CORRECT = 1, UNKNOWN = 2;
private Class<?> stat_class;
private ArrayList<Integer> histogram;
@@ -47,46 +43,61 @@
public int getTotal() {
return total_data;
}
-
- public void print2file(String dataFile, int domain_type, int tree_set) {
+ /*
+ * fileSignature must contain the folder location
+ * by default the folder = "src/main/rules/"
+ */
+ public void print2file(String fileSignature) {
- String packageFolders = this.stat_class.getPackage().getName();
-
- String _packageNames = packageFolders.replace('.', '/');
+ //String dataFileName = "src/main/rules/"+_packageNames+"/"+ fileName;
- String fileName = (dataFile == null || dataFile == "") ? this.stat_class.getSimpleName().toLowerCase() : dataFile;
+ if (!fileSignature.endsWith(".stats"))
+ fileSignature += ".stats";
+ System.out.println("printing stats:"+ fileSignature);
-
- String suffix = Util.getFileSuffix(domain_type, tree_set);
- fileName += "_"+suffix + ".stats";
-
-
-
- String dataFileName = "src/main/rules/"+_packageNames+"/"+ fileName;
try {
- StatsPrinter.print(this, new FileOutputStream(dataFileName));
+ StatsPrinter.print(this, new FileWriter(fileSignature));
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
}
}
public void print2out() {
- StatsPrinter.print(this, System.out);
+ try {
+ StatsPrinter.print(this, new PrintWriter(System.out));
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
}
+
+ public String print2string() {
+ StringWriter swr = new StringWriter();
+ try {
+ StatsPrinter.print(this, swr);
+ return swr.toString();
+ } catch (IOException e) {
+ e.printStackTrace();
+ return "";
+ }
+ }
}
class StatsPrinter {
- public static void print(Stats mystats, OutputStream os) {
- PrintWriter pwr = new PrintWriter(os);
+ //public static void print(Stats mystats, OutputStream os) {
+ public static void print(Stats mystats, Writer wr) throws IOException {
+ //PrintWriter pwr = new PrintWriter(os);
// print the statistics of the results to a file
- pwr.println("TESTING results: incorrect "+ mystats.getResult(Stats.INCORRECT));
- pwr.println("TESTING results: correct "+ mystats.getResult(Stats.CORRECT));
- pwr.println("TESTING results: unknown "+ mystats.getResult(Stats.UNKNOWN));
- pwr.println("TESTING results: Total Number "+ mystats.getTotal());
+ wr.write("TESTING results: incorrect "+ mystats.getResult(Stats.INCORRECT)+"\n");
+ wr.write("TESTING results: correct "+ mystats.getResult(Stats.CORRECT)+"\n");
+ wr.write("TESTING results: unknown "+ mystats.getResult(Stats.UNKNOWN)+"\n");
+ wr.write("TESTING results: Total Number "+ mystats.getTotal());
- pwr.flush();
+ wr.flush();
}
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/TreeNode.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -7,6 +7,10 @@
public class TreeNode {
+ //private static final Logger log = LoggerFactory.getSysOutLogger(TreeNode.class, LogLevel.ERROR);
+ //private static final Logger flog = LoggerFactory.getFileLogger(TreeNode.class);
+
+
private Domain domain;
private Hashtable<Object, TreeNode> children;
/* TODO explain
@@ -17,7 +21,7 @@
private double rank, gain, gainRatio;
// Number of all instances matching at that node
- private int num_matching_instances;
+ private double num_matching_instances;
public TreeNode(Domain domain) {
this.domain = domain;
@@ -33,11 +37,11 @@
this.rank = _rank;
}
- public void setNumMatch(int size) {
+ public void setNumMatch(double size) {
this.num_matching_instances= size;
}
- public int getNumMatch() {
+ public double getNumMatch() {
return this.num_matching_instances;
}
@@ -58,29 +62,27 @@
}
public Object voteFor(Instance i) {
- Object attr_value = i.getAttrValue(this.domain.getFName());
- Object category = domain.getCategoryOf(attr_value);
+ final Object attr_value = i.getAttrValue(this.domain.getFName());
+ final Object category = domain.getCategoryOf(attr_value);
- TreeNode my_node = this.getChild(category);
+ final TreeNode my_node = this.getChild(category);
- if (Util.DEBUG_TEST) {
- String out = "\nDomain:"+this.domain.getFName()+"->";
- for (int idx = 0; idx < this.domain.getCategoryCount(); idx++) {
- Object value = this.domain.getCategory(idx);
- out += value+"-";
- }
- out = Util.ntimes("$", 5) + out + " SEARCHING for = "+ attr_value + " in "+ this.domain.getFName();
-
- out += "\n KEYS:";
- for (Object key: this.getChildrenKeys()) {
-
- out += " "+key +"% "+this.getChild(key).getDomain() + " :";
- }
- System.out.print(out);
- System.out.print(" @myclass:"+category+ "\n");
- System.out.print(" @mynode:"+my_node+ "\n");
-
- }
+
+// flog.debug(new Object() {
+// public String toString() {
+//
+// StringBuffer sb = new StringBuffer(Util.ntimes("$", 5)+"\nDomain:"+domain.getFName()+"->");
+// for (int idx = 0; idx < domain.getCategoryCount(); idx++) {
+// sb.append(domain.getCategory(idx)+"-");
+// }
+// sb.append(" SEARCHING for = "+ attr_value + " in "+ domain.getFName());
+// sb.append("\n KEYS:");
+// for (Object key: getChildrenKeys()) {
+// sb.append(" "+key +"% "+ getChild(key).getDomain() + " :");
+// }
+// return sb.toString();
+// }
+// });
return my_node.voteFor(i);
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/C45Learner.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -11,14 +11,12 @@
import org.drools.learner.eval.InformationContainer;
import org.drools.learner.eval.InstDistribution;
import org.drools.learner.tools.FeatureNotSupported;
-import org.drools.learner.tools.Util;
public class C45Learner extends Learner{
-
public C45Learner() {
super();
- super.setDomainType(Util.C45);
+ super.setDomainAlgo(DomainAlgo.C45);
}
@@ -69,9 +67,9 @@
Entropy.chooseAttribute(best_attr_eval, data_stats, attribute_domains);
Domain node_domain = best_attr_eval.domain;
- if (Util.DEBUG_LEARNER) {
- System.out.println(Util.ntimes("*", 20)+" 1st best attr: "+ node_domain);
- }
+
+
+ //flog.debug(Util.ntimes("*", 20)+" 1st best attr: "+ node_domain);
TreeNode currentNode = new TreeNode(node_domain);
currentNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeBuilder.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/DecisionTreeBuilder.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -5,6 +5,13 @@
public interface DecisionTreeBuilder {
- void build(Memory wm, Learner l);
+ //public static final int SINGLE = 1, BAG = 2, BOOST = 3;
+ public static enum TreeAlgo { SINGLE, BAG, BOOST, BOOST_K }
+ void build(Memory wm, Learner trainer);
+
+// public Learner getLearner();
+
+ public TreeAlgo getTreeAlgo();
+
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ForestBuilder.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -2,145 +2,78 @@
import java.util.ArrayList;
-import org.drools.learner.AttributeValueComparator;
import org.drools.learner.DecisionTree;
-import org.drools.learner.Domain;
-import org.drools.learner.Instance;
import org.drools.learner.InstanceList;
import org.drools.learner.Memory;
-import org.drools.learner.Stats;
-import org.drools.learner.eval.ClassDistribution;
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
-public class ForestBuilder {
+public class ForestBuilder implements DecisionTreeBuilder{
+
+ private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(ForestBuilder.class, SimpleLogger.DEFAULT_LEVEL);
+ private static SimpleLogger slog = LoggerFactory.getSysOutLogger(ForestBuilder.class, SimpleLogger.DEFAULT_LEVEL);
- private int DOMAIN_TYPE;
- private int TREE_SET = Util.BAG; // default bagging, TODO boosting
+ private TreeAlgo algorithm = TreeAlgo.BAG; // default bagging, TODO boosting
- private static final int FOREST_SIZE = 10;
+ private static final int FOREST_SIZE = 50;
private static final double TREE_SIZE_RATIO = 0.9;
- private static final boolean BAGGING_WITH_REP = true;
+ private static final boolean WITH_REP = true;
private ArrayList<DecisionTree> forest;
+ //private Learner trainer;
- public void build(Memory mem, Learner trainer) {
- DOMAIN_TYPE = trainer.getDomainType();
+ public ForestBuilder() {
+ //this.trainer = _trainer;
+ }
+ public void build(Memory mem, Learner _trainer) {
- InstanceList class_instances = mem.getClassInstances();
- trainer.setInputData(class_instances);
+ final InstanceList class_instances = mem.getClassInstances();
+ _trainer.setInputData(class_instances);
+
if (class_instances.getTargets().size()>1) {
//throw new FeatureNotSupported("There is more than 1 target candidates");
- System.out.println("There is more than 1 target candidates");
+ if (flog.error() !=null)
+ flog.error().log("There is more than 1 target candidates");
System.exit(0);
// TODO put the feature not supported exception || implement it
}
-
- if (Util.DEBUG_LEARNER) {
- for (Instance inst: class_instances.getInstances()) {
- System.out.println("Inst: "+ inst);
- }
- }
-
+
int N = class_instances.getSize();
int tree_capacity = (int)(TREE_SIZE_RATIO * N);
- trainer.setDataSize(tree_capacity);
+ _trainer.setDataSizePerTree(tree_capacity);
forest = new ArrayList<DecisionTree> (FOREST_SIZE);
int i = 0;
int[] bag;
while (i++ < FOREST_SIZE) {
- if (BAGGING_WITH_REP)
+ if (WITH_REP)
bag = Util.bag_w_rep(tree_capacity, N);
else
bag = Util.bag_wo_rep(tree_capacity, N);
InstanceList working_instances = class_instances.getInstances(bag);
- DecisionTree dt = trainer.train_tree(working_instances);
+ DecisionTree dt = _trainer.train_tree(working_instances);
dt.setID(i);
forest.add(dt);
- System.out.print('.');
- }
- // TODO how to compute a best tree from the forest
- trainer.setBestTree(forest.get(0));
- //this.c45 = dt;
- }
- public Object voteOn(Instance i) {
-
- Domain target_domain = forest.get(0).getTargetDomain(); // all must have the same target domain
- ClassDistribution classification = new ClassDistribution(target_domain);
+ if (slog.stat() !=null)
+ slog.stat().stat(".");
- for (int j = 0; j< forest.size() ; j ++) {
- Object vote = forest.get(j).vote(i);
- if (vote != null)
- classification.change(vote, 1);
- else {
- // TODO add an unknown value
- //classification.change(-1, 1);
- System.out.println(Util.ntimes("\n", 10)+"Unknown situation at tree: " + j + " for fact "+ i);
- System.exit(0);
- }
}
- classification.evaluateMajority();
- Object winner = classification.get_winner_class();
+ // TODO how to compute a best tree from the forest
+ _trainer.setBestTree(forest.get(0));
-
- double ratio = 0.0;
- if (classification.get_num_ideas() == 1) {
- //100 %
- ratio = 1;
- return winner;
- } else {
- int num_votes = classification.getVoteFor(winner);
- ratio = ((double) num_votes/(double) forest.size());
- // TODO if the ratio is smaller than some number => reject
- }
- return winner;
-
+ //this.c45 = dt;
}
- public void test(InstanceList data) {
-
- Stats evaluation = new Stats(data.getSchema().getObjectClass()) ; //represent.getObjClass());
- Domain targetDomain = forest.get(0).getTargetDomain();
- int i = 0;
- for (Instance instance : data.getInstances()) {
- Object forest_decision = this.voteOn(instance);
- Integer result = evaluate(targetDomain, instance, forest_decision);
-
- if (Util.DEBUG_TEST) {
- System.out.println(Util.ntimes("#\n", 1)+i+ " <START> TEST: instant="+ instance + " = target "+ result);
- } else {
- if (i%1000 ==0) System.out.print(".");
- }
- evaluation.change(result, 1);
- i ++;
- }
-
- printStats(evaluation);
+ public TreeAlgo getTreeAlgo() {
+ return algorithm; //TreeAlgo.BAG; // default
}
- private static Integer evaluate (Domain targetDomain, Instance i, Object tree_decision) {
- String targetFName = targetDomain.getFName();
-
- Object tattr_value = i.getAttrValue(targetFName);
- Object i_category = targetDomain.getCategoryOf(tattr_value);
-
- if (AttributeValueComparator.instance.compare(i_category, tree_decision) == 0) {
- return Integer.valueOf(1); //correct
- } else {
- return Integer.valueOf(0); // mistake
- }
+ public ArrayList<DecisionTree> getTrees() {
+ return forest;
}
-
- private void printStats(Stats evaluation) {
- if (Util.PRINT_STATS) {
- if (Util.DEBUG_TEST) {
- evaluation.print2out();
- }
- evaluation.print2file("", DOMAIN_TYPE, TREE_SET);
- }
- }
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/ID3Learner.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -15,7 +15,7 @@
public ID3Learner() {
super();
- super.setDomainType(Util.ID3);
+ super.setDomainAlgo(DomainAlgo.ID3);
}
@@ -62,7 +62,8 @@
* */
Domain node_domain = Entropy.chooseAttributeAsCategorical(data_stats, attribute_domains);
- //System.out.println(Util.ntimes("*", 20)+" 1st best attr: "+ node_domain);
+ if (flog.stat() !=null)
+ flog.stat().log(Util.ntimes("*", 20)+" 1st best attr: "+ node_domain);
TreeNode currentNode = new TreeNode(node_domain);
currentNode.setNumMatch(data_stats.getSum()); //num of matching instances to the leaf node
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Learner.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Learner.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Learner.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -9,17 +9,21 @@
import org.drools.learner.InstanceList;
import org.drools.learner.TreeNode;
import org.drools.learner.eval.InstDistribution;
-import org.drools.learner.tools.Util;
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
public abstract class Learner {
+ protected static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(Learner.class, SimpleLogger.DEFAULT_LEVEL);
+
+ public static enum DomainAlgo { ID3, C45, WEIGHT }
private int data_size;
private DecisionTree best_tree;
private InstanceList input_data;
protected HashSet<Instance> missclassified_data;
- private int DOMAIN_TYPE;
+ private DomainAlgo algorithm;
protected abstract TreeNode train(DecisionTree dt, InstDistribution data_stats);
@@ -33,18 +37,15 @@
String target = this.getTargetDomain().getFName();
DecisionTree dt = new DecisionTree(input_data.getSchema(), target);
- if (Util.DEBUG_LEARNER) {
- System.out.println("Num of attributes: "+ dt.getAttrDomains().size());
- }
+ //flog.debug("Num of attributes: "+ dt.getAttrDomains().size());
+
InstDistribution stats_by_class = new InstDistribution(dt.getTargetDomain());
stats_by_class.calculateDistribution(working_instances.getInstances());
dt.FACTS_READ += working_instances.getSize();
TreeNode root = train(dt, stats_by_class);
dt.setRoot(root);
- if (Util.DEBUG_LEARNER) {
- System.out.println("Result tree\n" + dt);
- }
+ //flog.debug("Result tree\n" + dt);
return dt;
}
@@ -62,24 +63,21 @@
for (String target: input_data.getTargets()) {
dt = new DecisionTree(input_data.getSchema(), target);
- if (Util.DEBUG_LEARNER) {
- System.out.println("Num of attributes: "+ dt.getAttrDomains().size());
- }
+ //flog.debug("Num of attributes: "+ dt.getAttrDomains().size());
+
InstDistribution stats_by_class = new InstDistribution(dt.getTargetDomain());
stats_by_class.calculateDistribution(working_instances.getInstances());
dt.FACTS_READ += working_instances.getSize();
TreeNode root = train(dt, stats_by_class);
dt.setRoot(root);
- if (Util.DEBUG_LEARNER) {
- System.out.println("Result tree\n" + dt);
- }
+ //flog.debug("Result tree\n" + dt);
}
return dt;
}
- public void setDataSize(int num) {
+ public void setDataSizePerTree(int num) {
this.data_size = num;
missclassified_data = new HashSet<Instance>();
@@ -93,18 +91,21 @@
return best_tree;
}
- public int getDomainType() {
- return this.DOMAIN_TYPE;
+ public DomainAlgo getDomainAlgo() {
+ return this.algorithm;
}
- public void setDomainType(int type) {
- this.DOMAIN_TYPE = type;
+ public void setDomainAlgo(DomainAlgo type) {
+ this.algorithm = type;
}
public void setInputData(InstanceList class_instances) {
- this.input_data = class_instances;
-
+ this.input_data = class_instances;
}
+
+ public InstanceList getInputData() {
+ return input_data;
+ }
public void setBestTree(DecisionTree dt) {
this.best_tree = dt;
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/SingleTreeBuilder.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -1,97 +1,52 @@
package org.drools.learner.builder;
-import org.drools.learner.AttributeValueComparator;
+
import org.drools.learner.DecisionTree;
-import org.drools.learner.Domain;
-import org.drools.learner.Instance;
import org.drools.learner.InstanceList;
import org.drools.learner.Memory;
-import org.drools.learner.Stats;
-import org.drools.learner.tools.Util;
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
public class SingleTreeBuilder implements DecisionTreeBuilder{
- DecisionTree one_tree;
+ private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(SingleTreeBuilder.class, SimpleLogger.DEFAULT_LEVEL);
+ private static SimpleLogger slog = LoggerFactory.getSysOutLogger(SingleTreeBuilder.class, SimpleLogger.DEFAULT_LEVEL);
+
+ private TreeAlgo algorithm = TreeAlgo.SINGLE; // default bagging, TODO boosting
- private int DOMAIN_TYPE;
- private int TREE_SET = Util.SINGLE; // default bagging, TODO boosting
+ private DecisionTree one_tree;
+ //private Learner trainer;
+ public SingleTreeBuilder() {//Learner _trainer) {
+ //this.trainer = _trainer;
+ //dom_type = trainer.getDomainType();
+ }
+
/*
* the memory has the information
* the instances: the objects which the decision tree will work on
* the schema: the definition of the object instance
* (Class<?>) klass, String targetField, List<String> workingAttributes
*/
- public void build(Memory mem, Learner trainer) {
- DOMAIN_TYPE = trainer.getDomainType();
-
-
- InstanceList class_instances = mem.getClassInstances();
- trainer.setInputData(class_instances);
+ public void build(Memory mem, Learner _trainer) {
+ final InstanceList class_instances = mem.getClassInstances();
+ _trainer.setInputData(class_instances);
if (class_instances.getTargets().size()>1) {
//throw new FeatureNotSupported("There is more than 1 target candidates");
- System.out.println("There is more than 1 target candidates");
+ if (flog.error() !=null)
+ flog.error().log("There is more than 1 target candidates");
+
System.exit(0);
// TODO put the feature not supported exception || implement it
}
- if (Util.DEBUG_LEARNER)
- for (Instance inst: class_instances.getInstances()) {
- System.out.println("Inst: "+ inst);
- }
-
- trainer.setDataSize(class_instances.getSize());
- one_tree = trainer.train_tree(class_instances);
- trainer.setBestTree(one_tree);
+ _trainer.setDataSizePerTree(class_instances.getSize());
+ one_tree = _trainer.train_tree(class_instances);
+ _trainer.setBestTree(one_tree);
}
- /* test the entire set*/
- public void test(InstanceList data) {
- if (this.one_tree == null) {
- System.out.println("The tree is not created");
- System.exit(0);
- }
-
- if (Util.DEBUG_TEST) {
- System.out.println(Util.ntimes("\n", 2)+Util.ntimes("$", 5)+" TESTING "+Util.ntimes("\n", 2));
- }
-
- Stats evaluation = new Stats(this.one_tree.getObjClass());
- int i = 0;
- for (Instance instance : data.getInstances()) {
- Object tree_decision = this.one_tree.vote(instance);
- Integer result = evaluate(this.one_tree.getTargetDomain(), instance, tree_decision);
- if (Util.DEBUG_TEST) {
- System.out.println(Util.ntimes("#\n", 1)+i+ " <START> TEST: instant="+ instance + " = target "+ result);
- } else {
- if (i%1000 ==0) System.out.print(".");
- }
- evaluation.change(result, 1);
- i ++;
- }
- printStats(evaluation);
+ public TreeAlgo getTreeAlgo() {
+ return this.algorithm; // default
}
-
- private static Integer evaluate (Domain targetDomain, Instance i, Object tree_decision) {
- String targetFName = targetDomain.getFName();
-
- Object tattr_value = i.getAttrValue(targetFName);
- Object i_category = targetDomain.getCategoryOf(tattr_value);
-
- if (AttributeValueComparator.instance.compare(i_category, tree_decision) == 0) {
- return Integer.valueOf(1); //correct
- } else {
- return Integer.valueOf(0); // mistake
- }
- }
-
- private void printStats(Stats evaluation) {
- if (Util.PRINT_STATS) {
- if (Util.DEBUG_TEST) {
- evaluation.print2out();
- }
- evaluation.print2file("", DOMAIN_TYPE, TREE_SET);
- }
- }
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Tester.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Tester.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/builder/Tester.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -9,9 +9,11 @@
public abstract class Tester{
- private int DOMAIN_TYPE = 0, TREE_SET =0;
- public abstract void test(DecisionTreeBuilder builder, InstanceList data);
+ //private static final Logger log = LoggerFactory.getSysOutLogger(LogLevel.ERROR);
+ //private static final Logger flog = LoggerFactory.getFileLogger(Tester.class, LogLevel.ERROR, Util.log_file);
+ public abstract Stats test(InstanceList data);// String executionSignature
+
public static Integer evaluate (Domain targetDomain, Instance i, Object tree_decision) {
String targetFName = targetDomain.getFName();
@@ -25,12 +27,12 @@
}
}
- public void printStats(Stats evaluation) {
+ protected void printStats(final Stats evaluation, String executionSignature) {
if (Util.PRINT_STATS) {
- if (Util.DEBUG_TEST) {
- evaluation.print2out();
- }
- evaluation.print2file("", this.DOMAIN_TYPE, this.TREE_SET);
+// if (flog.debug() !=null)
+// flog.debug().log(evaluation.print2string());
+
+ evaluation.print2file(executionSignature);
}
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Categorizer.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Categorizer.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Categorizer.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -3,18 +3,19 @@
import java.util.ArrayList;
import java.util.Collections;
-
-
import org.drools.learner.Domain;
import org.drools.learner.Instance;
import org.drools.learner.InstanceComparator;
import org.drools.learner.QuantitativeDomain;
-
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
public class Categorizer {
+ private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(Categorizer.class, SimpleLogger.DEFAULT_LEVEL);
+
private ArrayList<Instance> data;
private QuantitativeDomain splitDomain;
private Domain targetDomain;
@@ -36,7 +37,7 @@
public Categorizer(InstDistribution _data_in_class) {
//List<Instance> _instancese
- this.data = new ArrayList<Instance>(_data_in_class.getSum());
+ this.data = new ArrayList<Instance>((int)_data_in_class.getSum());
this.distribution = _data_in_class;
this.targetDomain = _data_in_class.getClassDomain();
@@ -74,16 +75,18 @@
return bD;
}
- public ArrayList<Instance> getInstances() {
+ public ArrayList<Instance> getSortedInstances() {
return data;
}
private double find_a_split(int begin_index, int size, int depth,
ClassDistribution facts_in_class) {
+
+ if (flog.debug() !=null)
+ flog.debug().log("./n");
if (data.size() <= 1 || (size-begin_index)<2) {
- if (Util.DEBUG_CATEGORIZER) {
- System.out.println("fact.size <=1 returning 0.0....");
- }
+ if (flog.warn() !=null)
+ flog.warn().log("fact.size <=1 returning 0.0....");
return 0.0;
}
// if (facts_in_class.getSum() == 0) {
@@ -91,24 +94,21 @@
// }
facts_in_class.evaluateMajority();
if (facts_in_class.get_num_ideas()==1) {
- if (Util.DEBUG_CATEGORIZER) {
- System.out.println("getNum_supported_target_classes=1 returning 0.0....");
- }
+ if (flog.warn() !=null)
+ flog.warn().log("getNum_supported_target_classes=1 returning 0.0....");
return 0.0; //?
}
if (depth == 0) {
- if (Util.DEBUG_CATEGORIZER) {
- System.out.println("depth == 0 returning 0.0....");
- }
+ if (flog.warn() !=null)
+ flog.warn().log("depth == 0 returning 0.0....");
return 0.0;
}
/* initialize the distribution */
CondClassDistribution instances_by_attr = new CondClassDistribution(binaryDomain, this.targetDomain);
- instances_by_attr.setTotal(data.size());
instances_by_attr.setDistForAttrValue(key1, facts_in_class);
- instances_by_attr.setTotal(data.size());
+ instances_by_attr.setTotal(facts_in_class.getSum());
double best_sum = 100000.0;
@@ -116,23 +116,20 @@
int last_index = size-1;
Object last_value = data.get(last_index).getAttrValue(this.splitDomain.getFName());
SplitPoint bestPoint = new SplitPoint(last_index, last_value);
- //split_points.get(end_index-1);
- //SplitPoint bestPoint = split_points.get(split_points.size()-1);
CondClassDistribution best_distribution = null;//instances_by_attr;
Instance i1 = data.get(begin_index), i2;
- if (Util.DEBUG_CATEGORIZER) {
- System.out.println("\nentropy.info_cont() SEARCHING: "+begin_index+ " until "+size+" attr "+this.splitDomain.getFName()+ " "+ i1 );
- }
+ if (flog.debug() !=null)
+ flog.debug().log("\nentropy.info_cont() SEARCHING: "+begin_index+ " until "+size+" attr "+this.splitDomain.getFName()+ " "+ i1);
for(int index =begin_index+1; index < size; index ++) {
i2= data.get(index);
/* every time i read a new instance and change the place in the distribution */
- instances_by_attr.change(key0, i1.getAttrValue(this.targetDomain.getFName()), +1);
- instances_by_attr.change(key1, i1.getAttrValue(this.targetDomain.getFName()), -1);
- if (Util.DEBUG_CATEGORIZER) {
- System.out.println("BOK " +i1+" vs "+i2);
- }
+ instances_by_attr.change(key0, i1.getAttrValue(this.targetDomain.getFName()), +1.0d * i1.getWeight()); //+1
+ instances_by_attr.change(key1, i1.getAttrValue(this.targetDomain.getFName()), -1.0d * i1.getWeight()); //-1
+
+ if (flog.debug() !=null)
+ flog.debug().log("Instances " +i1+" vs "+i2);
/*
* CATEGORIZATION 2.1. Cut points are points in the sorted list above where the class labels change.
* Eg. if I had five instances with values for the attribute of interest and labels
@@ -142,9 +139,9 @@
*/
if ( targetComp.compare(i1, i2)!=0 && attrComp.compare(i1, i2)!=0) {
num_split_points++;
- if (Util.DEBUG_CATEGORIZER) {
- System.out.println("entropy.info_cont() SEARCHING: "+(index)+" attr "+this.splitDomain.getFName()+ " "+ i2 );
- }
+
+ if (flog.debug() !=null)
+ flog.debug().log("entropy.info_cont() SEARCHING: "+(index)+" attr "+this.splitDomain.getFName()+ " "+ i2);
// the cut point
Object cp_i = i1.getAttrValue(this.splitDomain.getFName());
Object cp_i_next = i2.getAttrValue(this.splitDomain.getFName());
@@ -162,11 +159,10 @@
if (sum < best_sum) {
best_sum = sum;
split_index = index;
-
- if (Util.DEBUG_CATEGORIZER) {
- System.out.println(Util.ntimes("?", 10)+"** FOUND: @"+(index)+" target ("+ i1.getAttrValue(this.targetDomain.getFName())
- +"-|T|-"+ i2.getAttrValue(this.targetDomain.getFName())+")");
- }
+
+ if (flog.debug() !=null)
+ flog.debug().log(Util.ntimes("?", 10)+"** FOUND: @"+(index)+" target ("+ i1.getAttrValue(this.targetDomain.getFName())
+ +"-|T|-"+ i2.getAttrValue(this.targetDomain.getFName())+")");
bestPoint = new SplitPoint(index-1, cut_point);
bestPoint.setInformationValue(best_sum);
@@ -179,18 +175,17 @@
}
if (best_distribution != null) {
- if (Util.DEBUG_CATEGORIZER) {
- System.out.println("bp:"+ bestPoint);
- }
+ if (flog.debug() !=null)
+ flog.debug().log("bp:"+ bestPoint);
this.splitDomain.addSplitPoint(bestPoint);
/*
* TODO : can we put the conditional class distribution to its correct place instead of the split
*/
- if (Util.DEBUG_CATEGORIZER) {
- System.out.println("bd:"+ best_distribution.getNumCondClasses());
- }
+
+ if (flog.debug() !=null)
+ flog.debug().log("bd:"+ best_distribution.getNumCondClasses());
double sum1 = find_a_split(begin_index, split_index, depth-1, best_distribution.getDistributionOf(key0));
double sum2 = find_a_split(split_index, size, depth-1, best_distribution.getDistributionOf(key1));
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/ClassDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/ClassDistribution.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/ClassDistribution.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -10,8 +10,12 @@
* class: categories of the target attribute*/
public class ClassDistribution {
+ //private static final Logger log = LoggerFactory.getSysOutLogger(LogLevel.ERROR);
+ //protected static final Logger flog = LoggerFactory.getFileLogger(ClassDistribution.class, LogLevel.ERROR, Util.log_file);
+
+
protected Domain target_attr;
- private Hashtable<Object, Integer> quantity_by_class;
+ private Hashtable<Object, Double> quantity_by_class;
private String sum_key = Util.sum();
private int num_category_ideas;
@@ -20,20 +24,20 @@
public ClassDistribution(Domain targetDomain) {
this.target_attr = targetDomain;
- this.quantity_by_class = new Hashtable<Object, Integer>(this.target_attr.getCategoryCount() + 1);
+ this.quantity_by_class = new Hashtable<Object, Double>(this.target_attr.getCategoryCount() + 1);
for (int c=0; c<this.target_attr.getCategoryCount(); c++) {
Object category = this.target_attr.getCategory(c);
- quantity_by_class.put(category, 0);
+ quantity_by_class.put(category, 0.0d);
}
- quantity_by_class.put(sum_key, 0);
+ quantity_by_class.put(sum_key, 0.0d);
num_category_ideas = 0;
}
public ClassDistribution(ClassDistribution copy_dist) {
this.target_attr = copy_dist.getClassDomain();
- this.quantity_by_class = new Hashtable<Object, Integer>(this.target_attr.getCategoryCount() + 1);
+ this.quantity_by_class = new Hashtable<Object, Double>(this.target_attr.getCategoryCount() + 1);
this.setDistribution(copy_dist);
this.num_category_ideas = copy_dist.get_num_ideas();
@@ -41,35 +45,35 @@
}
- public void setSum(int sum) {
+ public void setSum(double sum) {
quantity_by_class.put(sum_key, sum);
}
- public int getSum() {
- return quantity_by_class.get(sum_key).intValue();
+ public double getSum() {
+ return quantity_by_class.get(sum_key);
}
public Domain getClassDomain() {
return target_attr;
}
- public void change(Object target_category, int i) {
+ public void change(Object target_category, double i) {
/* TODO ????
* if (target_category == sum_key) return;
*/
- int num_1 = quantity_by_class.get(target_category).intValue();
+ double num_1 = quantity_by_class.get(target_category);
num_1 += i;
quantity_by_class.put(target_category, num_1);
//quantity_by_class.put(target_category, quantity_by_class.get(target_category)+i);
}
- public int getVoteFor(Object targetCategory) {
- return quantity_by_class.get(targetCategory).intValue();
+ public double getVoteFor(Object targetCategory) {
+ return quantity_by_class.get(targetCategory);
}
public void evaluateMajority() {
- int winner_vote = 0;
+ double winner_vote = 0.0d;
int num_ideas = 0; // the number of target categories that the instances belong to
Object winner = null;
@@ -77,7 +81,7 @@
for (int c=0; c<this.target_attr.getCategoryCount(); c++) {
Object category = this.target_attr.getCategory(c);
- int num_in_class = this.getVoteFor(category);
+ double num_in_class = this.getVoteFor(category);
if (num_in_class > 0) {
num_ideas++;
if (num_in_class > winner_vote) {
@@ -108,13 +112,13 @@
}
public String toString() {
- String out = "ClassDist: target:"+ this.target_attr.getFName()+ " total: "+ this.getSum() + " & categories:";
+ StringBuffer sb_out = new StringBuffer("ClassDist: target:"+ this.target_attr.getFName()+ " total: "+ this.getSum() + " & categories:");
for (int c=0; c<this.target_attr.getCategoryCount(); c++) {
Object category = this.target_attr.getCategory(c);
- out += this.getVoteFor(category) +" @"+category+ ", ";
+ sb_out.append(this.getVoteFor(category) +" @"+category+ ", ");
}
// out +="\n";
- return out;
+ return sb_out.toString();
}
public void setDistribution(ClassDistribution targetDist) {
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CondClassDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CondClassDistribution.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/CondClassDistribution.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -12,7 +12,7 @@
private Domain cond_attr; // domain of the attr we distribute the instances conditionally
private Hashtable<Object, ClassDistribution> cond_quantity_by_class;
- private int total_num;
+ private double total_num;
public CondClassDistribution(Domain attributeDomain, Domain targetDomain) {
this.cond_attr = attributeDomain;
@@ -38,11 +38,11 @@
}
- public int getTotal() {
+ public double getTotal() {
return this.total_num;
}
- public void setTotal(int size) {
+ public void setTotal(double size) {
this.total_num = size;
}
@@ -61,11 +61,11 @@
return cond_attr.getCategoryCount();
}
- public int getTotal_AttrCategory(Object attr_value) {
+ public double getTotal_AttrCategory(Object attr_value) {
return cond_quantity_by_class.get(attr_value).getSum();
}
- public void change(Object attr_category, Object target_class, int i) {
+ public void change(Object attr_category, Object target_class, double i) {
// System.out.print("The cond_dist: a_cat:"+ attr_category+" t_cat:"+ target_class);
// System.out.println(" the nums: "+cond_quantity_by_class.get(attr_category));
@@ -89,13 +89,13 @@
}
public String toString() {
- String out = "\nCondClassDist: attr: "+cond_attr.getFName()+" total num: "+ this.getTotal() + "\n" ;
+ StringBuffer sb_out = new StringBuffer("\nCondClassDist: attr: "+cond_attr.getFName()+" total num: "+ this.getTotal() + "\n") ;
for (int c_idx = 0; c_idx<cond_attr.getCategoryCount(); c_idx++) {
Object attr_category = cond_attr.getCategory(c_idx);
ClassDistribution td = cond_quantity_by_class.get(attr_category);
- out += "(ATTR:"+attr_category+ "=> "+ td +")";
+ sb_out.append("(ATTR:"+attr_category+ "=> "+ td +")");
}
- return out;
+ return sb_out.toString();
}
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Entropy.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Entropy.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/Entropy.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -7,10 +7,13 @@
import org.drools.learner.Domain;
import org.drools.learner.Instance;
import org.drools.learner.QuantitativeDomain;
+import org.drools.learner.tools.LoggerFactory;
+import org.drools.learner.tools.SimpleLogger;
import org.drools.learner.tools.Util;
public class Entropy {
+ private static SimpleLogger flog = LoggerFactory.getUniqueFileLogger(Entropy.class, SimpleLogger.DEFAULT_LEVEL);
//public Entropy
/*
* - chooses the best attribute,
@@ -43,14 +46,14 @@
Categorizer visitor = new Categorizer(insts_by_target);
visitor.findSplits(trialDomain);
- // trial domain modified
+ // trial domain is modified
if (trialDomain.getNumIndices()==1) {
gain = 0.0;
} else {
- gain = dt_info - info_contattr(visitor.getInstances(), insts_by_target.getClassDomain(), trialDomain);
+ gain = dt_info - info_contattr(visitor.getSortedInstances(), insts_by_target.getClassDomain(), trialDomain);
}
attr_domain = trialDomain;
- sorted_instances = visitor.getInstances();
+ sorted_instances = visitor.getSortedInstances();
/*
if (Util.DEBUG) {
@@ -65,7 +68,7 @@
}
/* */
}
- if (Util.DEBUG_ENTROPY) System.out.println("Attribute: " + attr_domain + " the gain: " + gain);
+ //flog.debug("Attribute: " + attr_domain + " the gain: " + gain);
if (gain > greatestGain) {
greatestGain = gain;
best_attr = attr_domain;
@@ -76,8 +79,7 @@
// Clone the best attribute domain cause it is going to be the domain of the treenode
eval.domain = best_attr.cheapClone();
eval.sorted_data = best_sorted_instances;
- eval.gain = greatestGain;
- //eval.gain_ratio = greatestGain/greatestSplitGain;
+ eval.attribute_eval = greatestGain;
return eval.domain;
}
@@ -107,7 +109,7 @@
* All domains are categorical so i will use them the way they are
*/
double gain = dt_info - info_attr(insts_by_target, attr_domain);
- //if (Util.DEBUG) System.out.println("Attribute: " + attr + " the gain: " + gain);
+ //flog.debug("Attribute: " + attr_domain.getFName() + " the gain: " + gain);
if (gain > greatestGain) {
greatestGain = gain;
best_attr = attr_domain;
@@ -122,17 +124,13 @@
Domain target_domain = insts_by_target.getClassDomain();
- if (Util.DEBUG_ENTROPY) {
- System.out.println("What is the attributeToSplit? " + attr_domain);
- }
+ //flog.debug("What is the attributeToSplit? " + attr_domain);
/* initialize the hashtable */
CondClassDistribution insts_by_attr = new CondClassDistribution(attr_domain, target_domain);
insts_by_attr.setTotal(insts_by_target.getSum());
- if (Util.DEBUG_ENTROPY) {
- System.out.println("Cond distribution for "+ attr_domain + " \n"+ insts_by_attr);
- }
+ //flog.debug("Cond distribution for "+ attr_domain + " \n"+ insts_by_attr);
for (int category = 0; category<target_domain.getCategoryCount(); category++) {
Object targetCategory = target_domain.getCategory(category);
@@ -143,10 +141,11 @@
Object inst_class = inst.getAttrValue(target_domain.getFName());
if (!targetCategory.equals(inst_class)) {
- System.out.println("How the fuck they are not the same ? "+ targetCategory + " " + inst_class);
+ if (flog.error() != null)
+ flog.error().log("How the fuck they are not the same ? "+ targetCategory + " " + inst_class);
System.exit(0);
}
- insts_by_attr.change(inst_attr_category, targetCategory, +1);
+ insts_by_attr.change(inst_attr_category, targetCategory, inst.getWeight()); //+1
}
}
@@ -158,6 +157,7 @@
/* calculates the information of a quantitative domain given the split indexes of instances
* a wrapper for the quantitative domain to be able to calculate the stats
* */
+ //public static double info_contattr(InstanceList data, Domain targetDomain, QuantitativeDomain splitDomain) {
public static double info_contattr(List<Instance> data, Domain targetDomain, QuantitativeDomain splitDomain) {
String targetAttr = targetDomain.getFName();
@@ -175,8 +175,7 @@
split_index++;
}
Object targetKey = i.getAttrValue(targetAttr);
- instances_by_attr.change(attr_key, targetKey, +1);
-
+ instances_by_attr.change(attr_key, targetKey, i.getWeight()); //+1
index++;
}
@@ -190,29 +189,23 @@
*/
public static double calc_info_attr( CondClassDistribution instances_by_attr) {
//Collection<Object> attributeValues = instances_by_attr.getAttributes();
- int data_size = instances_by_attr.getTotal();
+ double data_size = instances_by_attr.getTotal();
double sum = 0.0;
if (data_size>0)
for (int attr_idx=0; attr_idx<instances_by_attr.getNumCondClasses(); attr_idx++) {
Object attr_category = instances_by_attr.getCondClass(attr_idx);
- int total_num_attr = instances_by_attr.getTotal_AttrCategory(attr_category);
+ double total_num_attr = instances_by_attr.getTotal_AttrCategory(attr_category);
if (total_num_attr > 0) {
- double prob = (double) total_num_attr / (double) data_size;
- if (Util.DEBUG_ENTROPY) {
- System.out.print("{("+total_num_attr +"/"+data_size +":"+prob +")* [");
- }
+ double prob = total_num_attr / data_size;
+ //flog.debug("{("+total_num_attr +"/"+data_size +":"+prob +")* [");
double info = calc_info(instances_by_attr.getDistributionOf(attr_category));
sum += prob * info;
- if (Util.DEBUG_ENTROPY) {
- System.out.print("]} ");
- }
+ //flog.debug("]} ");
}
}
- if (Util.DEBUG_ENTROPY) {
- System.out.println("\n == "+sum);
- }
+ //flog.debug("\n == "+sum);
return sum;
}
@@ -225,28 +218,23 @@
*/
public static double calc_info(ClassDistribution quantity_by_class) {
- int data_size = quantity_by_class.getSum();
+ double data_size = quantity_by_class.getSum();
double prob, sum = 0;
- String out =" ";
Domain target_domain = quantity_by_class.getClassDomain();
for (int category = 0; category<target_domain.getCategoryCount(); category++) {
Object targetCategory = target_domain.getCategory(category);
- int num_in_class = quantity_by_class.getVoteFor(targetCategory);
+ double num_in_class = quantity_by_class.getVoteFor(targetCategory);
if (num_in_class > 0) {
- prob = (double) num_in_class / (double) data_size;
+ prob = num_in_class / data_size;
/* TODO what if it is a sooo small number ???? */
- if (Util.DEBUG_ENTROPY) {
- out += "("+num_in_class+ "/"+data_size+":"+prob+")" +"*"+ Util.log2(prob) + " + ";
- }
+ //flog.debug("("+num_in_class+ "/"+data_size+":"+prob+")" +"*"+ Util.log2(prob) + " + ");
sum -= prob * Util.log2(prob);
}
}
- if (Util.DEBUG_ENTROPY) {
- System.out.print(out +"= " +sum);
- }
+ //flog.debug("= " +sum);
return sum;
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InformationContainer.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -8,17 +8,18 @@
public class InformationContainer {
public Domain domain;
- public double gain;
- public double gain_ratio;
+ public double attribute_eval;
+ //public double gain_ratio;
public ArrayList<Instance> sorted_data;
public InformationContainer() {
}
- public InformationContainer(Domain _domain, double _gain, double _gain_ratio) {
- this.domain = _domain;
- this.gain = _gain;
- this.gain_ratio = _gain_ratio;
- }
+
+// public InformationContainer(Domain _domain, double _attribute_eval, double _gain_ratio) {
+// this.domain = _domain;
+// this.attribute_eval = _attribute_eval;
+// this.gain_ratio = _gain_ratio;
+// }
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/eval/InstDistribution.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -31,14 +31,15 @@
}
public void calculateDistribution(List<Instance> instances){
- int data_size = 0;
+ double data_size = 0.0;
String tName = super.getClassDomain().getFName();
for (Instance inst : instances) {
- data_size++;
+ data_size += inst.getWeight();
Object target_key = inst.getAttrValue(tName);
- super.change(target_key, +1); // add one for vote for the target value : target_key
+ super.change(target_key, inst.getWeight()); // add one for vote for the target value : target_key
+ //super.change(attr_sum, inst.getWeight()); // ?????
+
this.addSupporter(target_key, inst);
-
}
//super.change(attr_sum, data_size); // TODO should i write special function for changing the sum
super.setSum(data_size);
@@ -88,7 +89,8 @@
}
- public Hashtable<Object, InstDistribution> splitFromCategorical(Domain splitDomain, Hashtable<Object, InstDistribution> instLists) {
+ public Hashtable<Object, InstDistribution> splitFromCategorical(
+ Domain splitDomain, Hashtable<Object, InstDistribution> instLists) {
if (instLists == null)
instLists = this.instantiateLists(splitDomain);
@@ -101,8 +103,8 @@
for (Instance inst: this.getSupportersFor(targetCategory)) {
Object inst_attr_category = inst.getAttrValue(attrName);
- instLists.get(inst_attr_category).change(targetCategory, +1); // add one for vote for the target value : target_key
- instLists.get(inst_attr_category).change(attr_sum, +1);
+ instLists.get(inst_attr_category).change(targetCategory, inst.getWeight()); // add one for vote for the target value : target_key
+ instLists.get(inst_attr_category).change(attr_sum, inst.getWeight());
instLists.get(inst_attr_category).addSupporter(targetCategory, inst);
}
@@ -116,33 +118,8 @@
String attributeName = attributeDomain.getFName();
String targetName = super.getClassDomain().getFName();
- if (Util.DEBUG_DIST) {
- System.out.println("FactProcessor.splitFacts_cont() attr_split "+ attributeName);
- }
+ //flog.debug("FactProcessor.splitFacts_cont() attr_split "+ attributeName);
-// if (Util.DEBUG_DISTRIBUTION) {
-// System.out.println("FactProcessor.splitFacts_cont() haniymis benim repsentativelerim: "+ splitValues.size() + " and the split points "+ splitIndices.size());
-//
-// System.out.println("FactProcessor.splitFacts_cont() before splitting "+ facts.size());
-//
-// int index = 0;
-// int split_index = 0;
-// Object attr_key = splitValues.get(split_index);
-// for (Fact f : facts) {
-//
-// if (index == splitIndices.get(split_index).intValue()+1 ) {
-// System.out.print("PRINT* (");
-// attr_key = splitValues.get(split_index+1);
-// split_index++;
-// } else {
-// System.out.print("PRINT (");
-// }
-// System.out.println(split_index+"): fact "+f);
-// index++;
-// }
-//
-// }
-
int start_point = 0;
for (int index = 0; index < attributeDomain.getNumIndices(); index ++) {
@@ -153,19 +130,17 @@
try {
- if (Util.DEBUG_DIST) {
- System.out.println("FactProcessor.splitFacts_cont() new category: "+ inst_attr_category );
- System.out.println(" ("+start_point+","+integer_index+")");
- }
+ //flog.debug("FactProcessor.splitFacts_cont() new category: "+ inst_attr_category+
+ // " ("+start_point+","+integer_index+")");
List<Instance> data_at_category = data.subList(start_point, integer_index+1);
- for (Instance i: data_at_category) {
+ for (Instance inst: data_at_category) {
- Object targetCategory = i.getAttrValue(targetName);
+ Object targetCategory = inst.getAttrValue(targetName);
- instLists.get(inst_attr_category).change(targetCategory, +1); // add one for vote for the target value : target_key
- instLists.get(inst_attr_category).change(attr_sum, +1);
- instLists.get(inst_attr_category).addSupporter(targetCategory, i);
+ instLists.get(inst_attr_category).change(targetCategory, inst.getWeight()); // add one for vote for the target value : target_key
+ instLists.get(inst_attr_category).change(attr_sum, inst.getWeight());
+ instLists.get(inst_attr_category).addSupporter(targetCategory, inst);
}
start_point = integer_index+1;
@@ -181,7 +156,7 @@
for (int idx = 0; idx < super.target_attr.getCategoryCount(); idx++) {
Object looser = super.target_attr.getCategory(idx);
- int num_supp = this.getVoteFor(looser);
+ double num_supp = this.getVoteFor(looser);
if ((num_supp > 0) && !winner.equals(looser)) {
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/FeatureNotSupported.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/FeatureNotSupported.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/FeatureNotSupported.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -2,6 +2,11 @@
public class FeatureNotSupported extends Exception {
+ /**
+ *
+ */
+ private static final long serialVersionUID = 1L;
+
public FeatureNotSupported(String string) {
super(string);
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/NumberComparator.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/NumberComparator.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/NumberComparator.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -1,9 +1,8 @@
package org.drools.learner.tools;
-import java.io.Serializable;
import java.util.Comparator;
-public class NumberComparator implements Comparator<Number>, Serializable {
+public class NumberComparator implements Comparator<Number> {
public NumberComparator() {
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/ObjectFactory.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/ObjectFactory.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/ObjectFactory.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -59,8 +59,13 @@
file =new File("src/main/java/org/drools/examples/learner/"+filename);
if(!file.exists()){
+ file =new File("drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/"+filename);
+
System.out.println("where is still the file ? "+ file);
- System.exit(0);
+ if(!file.exists()){
+ System.out.println("where is still still the file ? "+ file);
+ System.exit(0);
+ }
}
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/RulePrinter.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/RulePrinter.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/RulePrinter.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -16,32 +16,28 @@
import org.drools.learner.DecisionTree;
import org.drools.learner.LeafNode;
import org.drools.learner.TreeNode;
-import org.drools.learner.builder.Learner;
public class RulePrinter {
- public static Reader readRules(Learner learner) {
- if (learner.getTree() == null) {
- System.out.println("There is tree/rule to process");
- return null;
- }
+// private static final Logger log = LoggerFactory.getSysOutLogger(RulePrinter.class, LogLevel.ERROR);
+// private static final Logger flog = LoggerFactory.getFileLogger(RulePrinter.class, LogLevel.ERROR, Util.log_file);
+//
+
+ public static Reader readRules(DecisionTree learned_dt) {
RulePrinter my_printer = new RulePrinter(); //bocuk.getNum_fact_trained()
my_printer.setBoundOnNumRules(Util.MAX_NUM_RULES);
- my_printer.printer(learner.getTree(), Util.SORT_RULES_BY_RANK);
+ my_printer.printer(learned_dt, Util.SORT_RULES_BY_RANK);
String all_rules = my_printer.write2string();
if (Util.PRINT_RULES) {
//my_printer.write2file("examples", "src/rules/examples/" + file);
- if (Util.DEBUG_RULE_PRINTER) {
- System.out.println(all_rules);
- }
- my_printer.write2File(all_rules, false, "", learner.getDomainType(), 0);
+ my_printer.write2File(all_rules, false, Util.DRL_DIRECTORY+learned_dt.getSignature());
}
return new StringReader(all_rules);
}
-
+
private Class<?> rule_clazz;
private Stack<NodeValue> nodes;
@@ -50,7 +46,8 @@
//private ArrayList<String> ruleText;
- private int bound_on_num_rules, num_instances;
+ private int bound_on_num_rules;
+ private double num_instances;
//private NumberComparator nComparator;
@@ -62,7 +59,7 @@
//ruleText = new ArrayList<String>();
this.bound_on_num_rules = -1;
- this.num_instances = -1;
+ this.num_instances = -1.0d;
//this.nComparator = new NumberComparator();
}
@@ -79,7 +76,7 @@
this.bound_on_num_rules = max_num_rules;
}
- public int getNumInstances() {
+ public double getNumInstances() {
return this.num_instances;
}
@@ -141,47 +138,38 @@
return newRule;
}
- public void write2File(String toWrite, boolean append, String dataFile, int domain_type, int tree_set) {
-
- String packageFolders = this.getRuleClass().getPackage().getName();
-
- String _packageNames = packageFolders.replace('.', '/');
-
- String fileName = (dataFile == null || dataFile == "") ? this.getRuleClass().getSimpleName().toLowerCase(): dataFile;
-
- String suffix = Util.getFileSuffix(domain_type, tree_set);
- fileName += "_"+suffix + ".drl";
-
- String dataFileName = "src/main/rules/"+_packageNames+"/"+ fileName;
-
- System.out.println("file:"+ dataFileName);
- File file =new File(dataFileName);
+ public void write2File(String toWrite, boolean append, String fileSignature) {//DomainType domain_type, int tree_set
+
+ //String drlFileName =
+ if (!fileSignature.endsWith(".drl"))
+ fileSignature += ".drl";
+ System.out.println("file:"+ fileSignature);
+ File file =new File(fileSignature);
if (append)
{
if(!file.exists())
- System.out.println("File doesnot exit, creating...");
+ //flog.warn("File doesnot exit, creating...");
try {
- BufferedWriter out = new BufferedWriter(new FileWriter(dataFileName, true));
+ BufferedWriter out = new BufferedWriter(new FileWriter(fileSignature, true));
out.write(toWrite);
out.close();
//System.out.println("I wrote "+ toWrite);
} catch (IOException e) {
- System.out.println("No I cannot write to the file (appending) e:"+ e);
+ //flog.error("No I couldnot append to the file e:"+ e);
/* TODO */
}
} else {
- if(file.exists()&& (file.length()>0))
+ if(file.exists()&& (file.length()>0)) {
file.delete();
+ //flog.warn("File exits, deleting...");
+ }
try {
- BufferedWriter out = new BufferedWriter(new FileWriter(dataFileName));
+ BufferedWriter out = new BufferedWriter(new FileWriter(fileSignature));
out.write(toWrite);
out.close();
- if (Util.DEBUG_RULE_PRINTER) {
- System.out.println("I wrote "+ toWrite);
- }
} catch (IOException e) {
- System.out.println("No I cannot write to the file (creating new file) e:"+ e);
+ //flog.error("No I couldnot create the file e:"+ e);
/* TODO */
}
}
@@ -192,29 +180,34 @@
String packageName = this.getRuleClass().getPackage().getName();
- if (Util.DEBUG_RULE_PRINTER) {
- System.out.println("Package name: "+ packageName);
- }
+ //log.debug("Package name: "+ packageName);
+
if (packageName != null)
outputBuffer.append("package " + packageName +";\n\n");
else {
//TODO throw exception
+ //flog.error("RulePrinter write2string packageName="+packageName);
}
- if (Util.DEBUG_RULE_PRINTER) {
- System.out.println("//Num of rules " +rules.size()+"\n");
- }
+// flog.debug(new Object() {
+// public String toString() {
+// String out = "//Num of rules " +rules.size()+"\n //this.getBoundOnNumRules() "+ getBoundOnNumRules();
+// return out;
+// }
+// });
+
int total_num_facts=0;
int i = 0, active_i = 0;
for( Rule rule: rules) {
i++;
+ //flog.debug("Rule: "+ i);
if (Util.ONLY_ACTIVE_RULES) {
if (rule.getRank() >= 0) {
active_i++;
- if (Util.DEBUG_RULE_PRINTER) {
- System.out.println("//Active rules " +i + " write to drl \n"+ rule +"\n");
- }
+// if (Util.DEBUG_RULE_PRINTER) {
+// System.out.println("//Active rules " +i + " write to drl \n"+ rule +"\n");
+// }
outputBuffer.append(rule.toString());
outputBuffer.append("\n");
}
@@ -223,9 +216,9 @@
if (rule.getRank() >= 0) {
active_i++;
}
- if (Util.DEBUG_RULE_PRINTER) {
- System.out.println("//rule " +i + " write to drl \n"+ rule +"\n");
- }
+// if (Util.DEBUG_RULE_PRINTER) {
+// System.out.println("//rule " +i + " write to drl \n"+ rule +"\n");
+// }
outputBuffer.append(rule.toString());
outputBuffer.append("\n");
}
@@ -248,7 +241,7 @@
private ArrayList<NodeValue> actions;
private double rank; // matching ratio
- private int num_classified_instances;// number of instances matching that rule
+ private double num_classified_instances;// number of instances matching that rule
private int id; // unique id, need a unique name in the drl file
@@ -285,10 +278,10 @@
this.id= id;
}
- private void setNumClassifiedInstances(int dataSize) {
+ private void setNumClassifiedInstances(double dataSize) {
this.num_classified_instances = dataSize;
}
- public int getNumClassifiedInstances() {
+ public double getNumClassifiedInstances() {
return this.num_classified_instances;
}
@@ -304,39 +297,54 @@
rule "Good Bye"
dialect "java"
when
- Message( status == Message.GOODBYE, message : message )
+ $m:Message( status == Message.GOODBYE)
then
- System.out.println( "Goodbye: " + message );
+ System.out.println( "[getLabel()] Expected value (" + $c.getLabel() + "), Classified as (False)");
end
*/
-
- String out = ""; //"rule \"#"+getId()+" "+decision+" rank:"+rank+"\" \n";
+ //"rule \"#"+getId()+" "+decision+" rank:"+rank+"\" \n";
+ StringBuffer sb_out = new StringBuffer("");
+ String obj_ref = "$"+this.getObjectClassName().substring(0, 1).toLowerCase();
- out += "\t when";
- out += "\n\t\t "+this.getObjectClassName() +"("+ "";
+ sb_out.append("\t when");
+ sb_out.append("\n\t\t "+obj_ref+":"+this.getObjectClassName() +"("+ "");
for (NodeValue cond: conditions) {
- out += cond + ", ";
+ sb_out.append(cond.toString() + ", ");
}
- String action = "";
- String decision = "";
+ StringBuffer sb_action = new StringBuffer("");
+ StringBuffer sb_field = new StringBuffer("");
+ StringBuffer sb_expected_value = new StringBuffer("");
for (NodeValue act: actions) {
- out += act.getFName() + " : "+act.getFName()+" , ";
- action += act.getNodeValue() + " , ";
- decision += act.getFName() + " ";
+ // if the query is on a field then i have to get its value during in the rule 'cause it might be private
+ if (!act.getNode().getDomain().isArtificial())
+ sb_out.append(obj_ref+ "_"+act.getFName() + " : "+act.getFName()+", ");
+
+ sb_action.append(act.getNodeValue() + " , ");
+ if (!act.getNode().getDomain().isArtificial())
+ sb_field.append(act.getFName() + "");
+ else
+ sb_field.append(act.getFName() + "()");
+
+
+ if (!act.getNode().getDomain().isArtificial())
+ sb_expected_value.append(obj_ref+ "_"+act.getFName());//reading the value by the reference of $o_fieldname
+ else
+ sb_expected_value.append(obj_ref+ "."+act.getFName() + "()");// reading the value from the object $o.function()
+
}
- action = action.substring(0, action.length()-3);
- out = out.substring(0, out.length()-3) + ")\n";
-
- out += "\t then ";
- out += "\n\t\t System.out.println(\"Decision on "+decision+"= \"+" + decision + "+\": ("+action+")\");\n";
+ sb_action.delete(sb_action.length()-3, sb_action.length()-1);
+ sb_out.delete(sb_out.length()-2, sb_out.length()-1);
+ sb_out.append(")\n");
+ sb_out.append("\t then ");
+ sb_out.append("\n\t\t System.out.println(\"["+sb_field.toString()+ "] Expected value (\" + "+ sb_expected_value.toString()+ " + \"), Classified as ("+sb_action.toString()+")\");\n");
if (getRank() <0)
- out += "\n\t\t System.out.println(\"But no matching fact found = DOES not fire on\");\n";
- out = "rule \"#"+getId()+" "+decision+ "= "+action+" classifying "+this.getNumClassifiedInstances()+" num of facts with rank:"+getRank() +"\" \n" + out;
+ sb_out.append("\n\t\t System.out.println(\"But no matching fact found = DOES not fire on\");\n");
- out += "end\n";
+ sb_out.insert(0, "rule \"#"+getId()+" "+sb_field.toString()+ "= "+sb_action.toString()+" classifying "+this.getNumClassifiedInstances()+" num of facts with rank:"+getRank() +"\" \n");
+ sb_out.append("end\n");
- return out;
+ return sb_out.toString();
}
public static Comparator<Rule> getRankComparator() {
@@ -359,6 +367,9 @@
class NodeValue {
+
+ //private static final Logger flog = LoggerFactory.getFileLogger(NodeValue.class, LogLevel.ERROR, Util.log_file);
+
private TreeNode node;
private Object nodeValue; // should it be Attribute???
@@ -404,9 +415,9 @@
Object categoryValue = node.getDomain().getCategory(idx);
if (nodeValue instanceof Comparable && categoryValue instanceof Comparable) {
// TODO ask this to daniel???
- if (Util.DEBUG_RULE_PRINTER) {
- System.out.println("NodeValue:"+ nodeValue+ " c-"+nodeValue.getClass() +" & category:"+ categoryValue+ " c-"+categoryValue.getClass());
- }
+// if (Util.DEBUG_RULE_PRINTER) {
+// System.out.println("NodeValue:"+ nodeValue+ " c-"+nodeValue.getClass() +" & category:"+ categoryValue+ " c-"+categoryValue.getClass());
+// }
if ( AttributeValueComparator.instance.compare(nodeValue, categoryValue) == 0 ) {
break;
}
@@ -431,9 +442,7 @@
else {
//return node.getDomain().getCategory(idx) + " < " + fName+ " <= "+ node.getDomain().getCategory(idx+1);
// Why drools does not support category(idx) < domain.name <= category(idx+1)
- if (Util.DEBUG_RULE_PRINTER) {
- System.out.println("value "+ value + "=====?????"+ node.getDomain().getCategory(idx+1));
- }
+ //flog.debug("value "+ value + "=====?????"+ node.getDomain().getCategory(idx+1));
return fName+ " <= "+ value; // node.getDomain().getCategory(idx+1);
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/Util.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/Util.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-core/src/main/java/org/drools/learner/tools/Util.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -7,36 +7,26 @@
public class Util {
- public static final int ID3 = 1, C45 = 2;
- public static final int SINGLE = 1, BAG = 2, BOOST = 3;
-
- //public static final boolean DEBUG = false;
- public static final boolean DEBUG_TEST = false;
-
public static final boolean PRINT_STATS = true;
-
+ public static final String DRL_DIRECTORY = "src/main/rules/";
+ /*
+ public static final boolean DEBUG = false;
+ public static final boolean DEBUG_TEST = false;
public static final boolean DEBUG_RULE_PRINTER = false;
-
public static final boolean DEBUG_LEARNER = false;
-
public static final boolean DEBUG_CATEGORIZER = false;
-
public static final boolean DEBUG_ENTROPY = false;
-
public static final boolean DEBUG_DIST = false;
-
public static final boolean DEBUG_DECISION_TREE = false;
-
-
-
- public static int MAX_NUM_RULES = 1000;
+ */
+ public static int MAX_NUM_RULES = 10;
public static boolean ONLY_ACTIVE_RULES = true; /* TODO into global settings */
public static boolean SORT_RULES_BY_RANK = true;
public static boolean PRINT_RULES = true;
- private static boolean WITH_REP = true;
private static Random BAGGING = new Random(System.currentTimeMillis());
+ //public static String log_file = "testing.log";
public static String ntimes(String s,int n){
@@ -55,6 +45,14 @@
return Math.log(prob) / Math.log(2);
}
+ public static double ln(double prob) {
+ return Math.log(prob);
+ }
+
+ public static double exp(double prob) {
+ return Math.exp(prob);
+ }
+
/* TODO make this all_fields arraylist as hashmap */
public static void getAllFields(Class<?> clazz, ArrayList<Field> all_fields) {
if (clazz == Object.class)
@@ -69,6 +67,21 @@
return;
}
+
+ /* TODO make this all_fields arraylist as hashmap */
+ public static void getAllFields(Class<?> clazz, ArrayList<Field> all_fields, ArrayList<Class<?>> all_classes) {
+ if (clazz == Object.class)
+ return;
+ all_classes.add(clazz);
+ //Field [] element_fields_ = clazz.getFields();
+ Field [] element_fields = clazz.getDeclaredFields(); //clazz.getFields();
+ for (Field f: element_fields) {
+ all_fields.add(f);
+ }
+ getAllFields(clazz.getSuperclass(), all_fields, all_classes);
+
+ return;
+ }
public static Object calculateMidPoint(Class<?> fClass, Object cp_i, Object cp_i_next) {
if (fClass.isAssignableFrom(Integer.class) || fClass == Integer.TYPE) {
@@ -191,38 +204,7 @@
}
- public static String getFileSuffix(int domain_type, int tree_set) {
- String suffix = "";
- switch (domain_type) {
- case Util.ID3:
- suffix += "id3" ;
- break;
- case Util.C45:
- suffix += "c45";
- break;
- default:
- suffix += "?" ;
-
- }
-
- switch (tree_set) {
- case Util.SINGLE:
- suffix += "_one";
- break;
- case Util.BAG:
- suffix += "_bag";
- break;
- case Util.BOOST:
- suffix += "_boost";
- break;
- default:
- suffix += "_?" ;
-
- }
- return suffix;
- }
-
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Car.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Car.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Car.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -1,8 +1,9 @@
package org.drools.examples.learner;
+import org.drools.learner.tools.ClassAnnotation;
import org.drools.learner.tools.FieldAnnotation;
-
+ at ClassAnnotation(label_element = "getLabel2")
public class Car {
@FieldAnnotation(readingSeq = 0)
private String buying; //"vhigh", "high", "med", "low"
@@ -16,10 +17,12 @@
private String lug_boot; //"small", "med", "big"
@FieldAnnotation(readingSeq = 5)
private String safety; //"low", "med", "high"
- @FieldAnnotation(readingSeq = 6, target = true)
+ @FieldAnnotation(readingSeq = 6) //, target = true)
private String target; //"unacc", "acc", "good", "vgood"
-
+ public boolean getLabel2() {
+ return (doors.equalsIgnoreCase("5more") && safety.equalsIgnoreCase("med") && buying.equalsIgnoreCase("low"));
+ }
public Car() {
}
@@ -80,6 +83,11 @@
}
+ public boolean getLabel() {
+ return (doors.equalsIgnoreCase("5more"));
+ }
+
+
public String toString() {
String out = "Car(buy:" +getBuying() +
" doors:"+getDoors()+
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/CarExample.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -9,8 +9,9 @@
import org.drools.compiler.PackageBuilder;
import org.drools.event.DebugAgendaEventListener;
import org.drools.event.DebugWorkingMemoryEventListener;
-import org.drools.learner.builder.Learner;
-import org.drools.learner.builder.LearnerFactory;
+import org.drools.learner.DecisionTree;
+import org.drools.learner.builder.DecisionTreeBuilder;
+import org.drools.learner.builder.DecisionTreeFactory;
import org.drools.learner.tools.ObjectFactory;
public class CarExample {
@@ -22,12 +23,11 @@
final StatefulSession session = ruleBase.newStatefulSession(); // LearningSession
- // what are these listeners???
- session.addEventListener( new DebugAgendaEventListener() );
- session.addEventListener( new DebugWorkingMemoryEventListener() );
+ //session.addEventListener( new DebugAgendaEventListener() );
+ //session.addEventListener( new DebugWorkingMemoryEventListener() );
- final WorkingMemoryFileLogger logger = new WorkingMemoryFileLogger( session );
- logger.setFileName( "log/car" );
+ //final WorkingMemoryFileLogger logger = new WorkingMemoryFileLogger( session );
+ //logger.setFileName( "log/car" );
String inputFile = new String("data/car/car.data.txt");
Class<?> obj_class = Car.class;
@@ -37,12 +37,26 @@
}
// instantiate a learner for a specific object class and pass session to train
- Learner learner = LearnerFactory.createID3(session, obj_class);
- //Learner learner = LearnerFactory.createC45(session, obj_class);
+ DecisionTree decision_tree; int ALGO = 2;
+ switch (ALGO) {
+ case 1:
+ decision_tree = DecisionTreeFactory.createBaggedC45(session, obj_class);
+ break;
+ case 2:
+ decision_tree = DecisionTreeFactory.createBoostedC45(session, obj_class);
+ break;
+ case 3:
+ decision_tree = DecisionTreeFactory.createSingleID3(session, obj_class);
+ break;
+ default:
+ decision_tree = DecisionTreeFactory.createSingleC45(session, obj_class);
+ }
+
final PackageBuilder builder = new PackageBuilder();
//this wil generate the rules, then parse and compile in one step
- builder.addPackageFromLearner( learner );
+ builder.addPackageFromTree( decision_tree );
+ System.exit(0);
ruleBase.addPackage( builder.getPackage() );
/*
final Reader source = new InputStreamReader( HelloWorldExample.class.getResourceAsStream( "HelloWorld.drl" ) );
@@ -54,7 +68,7 @@
session.fireAllRules();
- logger.writeToDisk();
+ //logger.writeToDisk();
session.dispose();
}
Deleted: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfC45Example.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfC45Example.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfC45Example.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -1,61 +0,0 @@
-package org.drools.examples.learner;
-
-import java.util.List;
-
-import org.drools.RuleBase;
-import org.drools.RuleBaseFactory;
-import org.drools.StatefulSession;
-import org.drools.audit.WorkingMemoryFileLogger;
-import org.drools.compiler.PackageBuilder;
-import org.drools.event.DebugAgendaEventListener;
-import org.drools.event.DebugWorkingMemoryEventListener;
-import org.drools.learner.builder.Learner;
-import org.drools.learner.builder.LearnerFactory;
-import org.drools.learner.tools.ObjectFactory;
-
-public class GolfC45Example {
-
- public static final void main(final String[] args) throws Exception {
- // my rule base
- final RuleBase ruleBase = RuleBaseFactory.newRuleBase();
- //ruleBase.addPackage( pkg );
-
- final StatefulSession session = ruleBase.newStatefulSession();
- // LearningSession
-
- // what are these listeners???
- session.addEventListener( new DebugAgendaEventListener() );
- session.addEventListener( new DebugWorkingMemoryEventListener() );
-
- final WorkingMemoryFileLogger logger = new WorkingMemoryFileLogger( session );
- logger.setFileName( "log/golf_c45" );
-
- String inputFile = new String("data/golf/golf.data.txt");
- Class<?> obj_class = Golf.class;
- List<Object> facts = ObjectFactory.getObjects(obj_class, inputFile);
- for (Object r : facts) {
- session.insert(r);
- }
-
- // instantiate a learner for a specific object class and pass session to train
- Learner learner ;
- //learner = LearnerFactory.createC45(session, obj_class);
- learner = LearnerFactory.createC45fromBag(session, obj_class);
-
- final PackageBuilder builder = new PackageBuilder();
- //this wil generate the rules, then parse and compile in one step
- builder.addPackageFromLearner( learner );
- /*
- * get the compiled package (which is serializable) from the builder
- * add the package to a rulebase (deploy the rule package).
- */
- ruleBase.addPackage( builder.getPackage() );
-
- session.fireAllRules();
-
- logger.writeToDisk();
-
- session.dispose();
- }
-
-}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfC45ExampleFromDrl.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfC45ExampleFromDrl.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfC45ExampleFromDrl.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -19,7 +19,7 @@
public static final void main(final String[] args) throws Exception {
//read in the source
//final Reader source = new InputStreamReader( HelloWorldExample.class.getResourceAsStream( "HelloWorld.drl" ) );
- final Reader source = new InputStreamReader( Restaurant.class.getResourceAsStream( "golf2.drl" ) );
+ final Reader source = new InputStreamReader( Golf.class.getResourceAsStream( "golf2.drl" ) );
final PackageBuilder builder = new PackageBuilder();
Deleted: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -1,62 +0,0 @@
-package org.drools.examples.learner;
-
-import java.util.List;
-
-import org.drools.RuleBase;
-import org.drools.RuleBaseFactory;
-import org.drools.StatefulSession;
-import org.drools.audit.WorkingMemoryFileLogger;
-import org.drools.compiler.PackageBuilder;
-import org.drools.event.DebugAgendaEventListener;
-import org.drools.event.DebugWorkingMemoryEventListener;
-import org.drools.learner.builder.Learner;
-import org.drools.learner.builder.LearnerFactory;
-import org.drools.learner.tools.ObjectFactory;
-
-public class GolfExample {
-
- public static final void main(final String[] args) throws Exception {
- // my rule base
- final RuleBase ruleBase = RuleBaseFactory.newRuleBase();
- //ruleBase.addPackage( pkg );
-
- final StatefulSession session = ruleBase.newStatefulSession();
- // LearningSession
-
- // what are these listeners???
- session.addEventListener( new DebugAgendaEventListener() );
- session.addEventListener( new DebugWorkingMemoryEventListener() );
-
- final WorkingMemoryFileLogger logger = new WorkingMemoryFileLogger( session );
- logger.setFileName( "log/golf" );
-
- String inputFile = new String("data/golf/golf.data.txt");
- Class<?> obj_class = Golf.class;
- List<Object> facts = ObjectFactory.getObjects(obj_class, inputFile);
- for (Object r : facts) {
- session.insert(r);
- }
-
- // instantiate a learner for a specific object class and pass session to train
- Learner learner = LearnerFactory.createID3(session, obj_class);
-
- final PackageBuilder builder = new PackageBuilder();
- //this wil generate the rules, then parse and compile in one step
- builder.addPackageFromLearner( learner );
- ruleBase.addPackage( builder.getPackage() );
- /*
- final Reader source = new InputStreamReader( HelloWorldExample.class.getResourceAsStream( "HelloWorld.drl" ) );
- //get the compiled package (which is serializable)
- final Package pkg = builder.getPackage();
- //add the package to a rulebase (deploy the rule package).
- ruleBase.addPackage( pkg );
- */
-
- session.fireAllRules();
-
- logger.writeToDisk();
-
- session.dispose();
- }
-
-}
Copied: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java (from rev 20164, labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfC45Example.java)
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java (rev 0)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/GolfExample.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -0,0 +1,94 @@
+package org.drools.examples.learner;
+
+import java.util.List;
+
+import org.drools.RuleBase;
+import org.drools.RuleBaseFactory;
+import org.drools.StatefulSession;
+import org.drools.audit.WorkingMemoryFileLogger;
+import org.drools.compiler.PackageBuilder;
+import org.drools.event.DebugAgendaEventListener;
+import org.drools.event.DebugWorkingMemoryEventListener;
+import org.drools.learner.DecisionTree;
+import org.drools.learner.builder.DecisionTreeFactory;
+import org.drools.learner.tools.ObjectFactory;
+
+public class GolfExample {
+
+ public static final void main(final String[] args) throws Exception {
+ long start_time = System.currentTimeMillis();
+ // my rule base
+ final RuleBase ruleBase = RuleBaseFactory.newRuleBase();
+ //ruleBase.addPackage( pkg );
+
+ final StatefulSession session = ruleBase.newStatefulSession();
+ // LearningSession
+
+ // what are these listeners???
+ session.addEventListener( new DebugAgendaEventListener() );
+ session.addEventListener( new DebugWorkingMemoryEventListener() );
+
+
+
+ final WorkingMemoryFileLogger logger = new WorkingMemoryFileLogger( session );
+ logger.setFileName( "log/golf_c45" );
+
+ String inputFile = new String("data/golf/golf.data.txt");
+ Class<?> obj_class = Golf.class;
+ List<Object> facts = ObjectFactory.getObjects(obj_class, inputFile);
+ for (Object r : facts) {
+ session.insert(r);
+ }
+
+ // instantiate a learner for a specific object class and pass session to train
+ DecisionTree decision_tree; int ALGO = 7;
+ switch (ALGO) {
+ case 1:
+ decision_tree = DecisionTreeFactory.createBaggedC45(session, obj_class);
+ break;
+ case 2:
+ decision_tree = DecisionTreeFactory.createBoostedC45(session, obj_class);
+ break;
+ case 3:
+ decision_tree = DecisionTreeFactory.createGlobal2(session, obj_class);
+ break;
+ case 4:
+ decision_tree = DecisionTreeFactory.createBagC45Entropy(session, obj_class);
+ break;
+ case 5:
+ decision_tree = DecisionTreeFactory.createBagC45GainRatio(session, obj_class);
+ break;
+ case 6:
+ decision_tree = DecisionTreeFactory.createSingleID3(session, obj_class);
+ break;
+ case 7:
+ decision_tree = DecisionTreeFactory.createSingleC45Entropy(session, obj_class);
+ break;
+ case 8:
+ decision_tree = DecisionTreeFactory.createSingleC45GainRatio(session, obj_class);
+ break;
+ default:
+ decision_tree = DecisionTreeFactory.createSingleC45(session, obj_class);
+
+ }
+
+ final PackageBuilder builder = new PackageBuilder();
+ //this wil generate the rules, then parse and compile in one step
+ builder.addPackageFromTree( decision_tree );
+ /*
+ * get the compiled package (which is serializable) from the builder
+ * add the package to a rulebase (deploy the rule package).
+ */
+ ruleBase.addPackage( builder.getPackage() );
+
+ session.fireAllRules();
+
+ long end_time = System.currentTimeMillis();
+ System.out.println("Total time="+ (end_time-start_time));
+
+ logger.writeToDisk();
+
+ session.dispose();
+ }
+
+}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/NurseryExample.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -9,8 +9,8 @@
import org.drools.compiler.PackageBuilder;
import org.drools.event.DebugAgendaEventListener;
import org.drools.event.DebugWorkingMemoryEventListener;
-import org.drools.learner.builder.Learner;
-import org.drools.learner.builder.LearnerFactory;
+import org.drools.learner.DecisionTree;
+import org.drools.learner.builder.DecisionTreeFactory;
import org.drools.learner.tools.ObjectFactory;
public class NurseryExample {
@@ -38,11 +38,11 @@
// instantiate a learner for a specific object class and pass session to train
//Learner learner = LearnerFactory.createID3(session, obj_class);
- Learner learner = LearnerFactory.createC45(session, obj_class);
+ DecisionTree dt_builder = DecisionTreeFactory.createBaggedC45(session, obj_class);
final PackageBuilder builder = new PackageBuilder();
//this wil generate the rules, then parse and compile in one step
- builder.addPackageFromLearner( learner );
+ builder.addPackageFromTree( dt_builder );
ruleBase.addPackage( builder.getPackage() );
/*
final Reader source = new InputStreamReader( HelloWorldExample.class.getResourceAsStream( "HelloWorld.drl" ) );
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Poker.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Poker.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/Poker.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -1,9 +1,9 @@
package org.drools.examples.learner;
+import org.drools.learner.tools.ClassAnnotation;
import org.drools.learner.tools.FieldAnnotation;
-
-
+ @ClassAnnotation(label_element = "getLabel")
public class Poker {
@FieldAnnotation(readingSeq = 0)
private int s1; // 'Suit of card #1': Ordinal (1-4) representing {Hearts, Spades, Diamonds, Clubs}
@@ -30,13 +30,29 @@
@FieldAnnotation(readingSeq = 9, discrete=false)
private int c5; // 'Rank of card #5': Numerical (1-13) representing (Ace, 2, 3, ... , Queen, King)
- @FieldAnnotation(readingSeq = 10, target = true)
+ @FieldAnnotation(readingSeq = 10, ignore = true)
private int poker_hand;
+ /*
+ *0: Nothing in hand; not a recognized poker hand
+ 1: One pair; one pair of equal ranks within five cards
+ 2: Two pairs; two pairs of equal ranks within five cards
+ 3: Three of a kind; three equal ranks within five cards
+ 4: Straight; five cards, sequentially ranked with no gaps
+ 5: Flush; five cards with the same suit
+ 6: Full house; pair + different rank three of a kind
+ 7: Four of a kind; four equal ranks within five cards
+ 8: Straight flush; straight + flush
+ 9: Royal flush; {Ace, King, Queen, Jack, Ten} + flush
+ */
public Poker() {
}
+ public boolean getLabel() {
+ return poker_hand>=4;
+ }
+
public int getS1() {
return s1;
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/PokerExample.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -5,17 +5,15 @@
import org.drools.RuleBase;
import org.drools.RuleBaseFactory;
import org.drools.StatefulSession;
-import org.drools.audit.WorkingMemoryFileLogger;
import org.drools.compiler.PackageBuilder;
-import org.drools.event.DebugAgendaEventListener;
-import org.drools.event.DebugWorkingMemoryEventListener;
-import org.drools.learner.builder.Learner;
-import org.drools.learner.builder.LearnerFactory;
+import org.drools.learner.DecisionTree;
+import org.drools.learner.builder.DecisionTreeFactory;
import org.drools.learner.tools.ObjectFactory;
public class PokerExample {
public static final void main(final String[] args) throws Exception {
+ long start_time = System.currentTimeMillis();
// my rule base
final RuleBase ruleBase = RuleBaseFactory.newRuleBase();
//ruleBase.addPackage( pkg );
@@ -23,12 +21,12 @@
final StatefulSession session = ruleBase.newStatefulSession();
// LearningSession
- // what are these listeners???
- session.addEventListener( new DebugAgendaEventListener() );
- session.addEventListener( new DebugWorkingMemoryEventListener() );
-
- final WorkingMemoryFileLogger logger = new WorkingMemoryFileLogger( session );
- logger.setFileName( "log/poker" );
+// // what are these listeners???
+// session.addEventListener( new DebugAgendaEventListener() );
+// session.addEventListener( new DebugWorkingMemoryEventListener() );
+//
+// final WorkingMemoryFileLogger logger = new WorkingMemoryFileLogger( session );
+// logger.setFileName( "log/poker" );
String inputFile = new String("data/poker/poker-hand-training-true.data.txt");
Class<?> obj_class = Poker.class;
@@ -38,13 +36,26 @@
}
// instantiate a learner for a specific object class and pass session to train
- Learner learner ;
- //learner = LearnerFactory.createC45(session, obj_class);
- learner = LearnerFactory.createC45fromBag(session, obj_class);
-
+ // instantiate a learner for a specific object class and pass session to train
+ DecisionTree decision_tree; int ALGO = 2;
+ switch (ALGO) {
+ case 1:
+ decision_tree = DecisionTreeFactory.createBaggedC45(session, obj_class);
+ break;
+ case 2:
+ decision_tree = DecisionTreeFactory.createBoostedC45(session, obj_class);
+ break;
+ case 3:
+ decision_tree = DecisionTreeFactory.createSingleID3(session, obj_class);
+ break;
+ default:
+ decision_tree = DecisionTreeFactory.createSingleC45(session, obj_class);
+
+ }
+
final PackageBuilder builder = new PackageBuilder();
//this wil generate the rules, then parse and compile in one step
- builder.addPackageFromLearner( learner );
+ builder.addPackageFromTree( decision_tree );
/*
* get the compiled package (which is serializable) from the builder
* add the package to a rulebase (deploy the rule package).
@@ -52,9 +63,11 @@
ruleBase.addPackage( builder.getPackage() );
session.fireAllRules();
+ long end_time = System.currentTimeMillis();
+ System.out.println("Total time="+ (end_time-start_time));
+
+// logger.writeToDisk();
- logger.writeToDisk();
-
session.dispose();
}
Modified: labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/RestaurantExample.java
===================================================================
--- labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/RestaurantExample.java 2008-06-09 18:26:19 UTC (rev 20384)
+++ labs/jbossrules/contrib/machinelearning/4.0.x/drools-examples/drools-examples-drl/src/main/java/org/drools/examples/learner/RestaurantExample.java 2008-06-09 18:37:43 UTC (rev 20385)
@@ -10,9 +10,10 @@
import org.drools.compiler.PackageBuilder;
import org.drools.event.DebugAgendaEventListener;
import org.drools.event.DebugWorkingMemoryEventListener;
-import org.drools.learner.builder.Learner;
-import org.drools.learner.builder.LearnerFactory;
+import org.drools.learner.DecisionTree;
+import org.drools.learner.builder.DecisionTreeFactory;
+
public class RestaurantExample {
public static final void main(final String[] args) throws Exception {
@@ -36,11 +37,11 @@
}
// instantiate a learner for a specific object class and pass session to train
- Learner learner = LearnerFactory.createID3(session, Restaurant.class);
+ DecisionTree dt_builder = DecisionTreeFactory.createSingleID3(session, Restaurant.class);
final PackageBuilder builder = new PackageBuilder();
//this wil generate the rules, then parse and compile in one step
- builder.addPackageFromLearner( learner );
+ builder.addPackageFromTree( dt_builder );
ruleBase.addPackage( builder.getPackage() );
/*
final Reader source = new InputStreamReader( HelloWorldExample.class.getResourceAsStream( "HelloWorld.drl" ) );
More information about the jboss-svn-commits
mailing list