00001 import java.io.FileNotFoundException;
00002 import java.util.HashMap;
00003 import java.util.Set;
00004 import java.util.TreeSet;
00005 import java.util.Vector;
00006
00007 import weka.classifiers.trees.J48;
00008 import weka.core.Attribute;
00009 import weka.core.FastVector;
00010 import weka.core.Instance;
00011 import weka.core.Instances;
00012
00013 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00014 import edu.ksu.cis.bnj.ver3.core.CPF;
00015 import edu.ksu.cis.bnj.ver3.core.Discrete;
00016 import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
00017 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00018 import edu.tum.cs.srl.bayesnets.CPT2MLNFormulas;
00019 import edu.tum.cs.srl.bayesnets.RelationalBeliefNetwork;
00020
00021
00022 public class CPTDecTree {
00023
00027 public static void main(String[] args) {
00028 try {
00029 RelationalBeliefNetwork bn = new RelationalBeliefNetwork("actseq.xml");
00030
00031 CPT2MLNFormulas cpt2mln = new CPT2MLNFormulas(bn.getRelationalNode("performs(p,a2,s2)"));
00032 cpt2mln.convert(System.out);
00033 System.exit(0);
00034
00035 FastVector fvAttribs = new FastVector(bn.bn.getNodes().length+1);
00036 HashMap<String, Attribute> attrs = new HashMap<String, Attribute>();
00037 for(BeliefNode node : bn.bn.getNodes()) {
00038 Discrete dom = (Discrete)node.getDomain();
00039 FastVector attValues = new FastVector(dom.getOrder());
00040 for(int i = 0; i < dom.getOrder(); i++)
00041 attValues.addElement(dom.getName(i));
00042 Attribute attr = new Attribute(node.getName(), attValues);
00043 attrs.put(node.getName(), attr);
00044 fvAttribs.addElement(attr);
00045 }
00046
00047 BeliefNode node = bn.getNode("performs(p,a2,s2)");
00048
00049
00050
00051 CPF cpf = node.getCPF();
00052 TreeSet<Double> probs = new TreeSet<Double>();
00053 walkCPT2(cpf, new int[bn.bn.getNodes().length], 0, probs);
00054 FastVector attrValues = new FastVector(probs.size());
00055 for(Double d : probs)
00056 attrValues.addElement(Double.toString(d));
00057
00058 Attribute probAttr = new Attribute("prob", attrValues);
00059 attrs.put("prob", probAttr);
00060 fvAttribs.addElement(probAttr);
00061
00062 Instances instances = new Instances(node.getName(), fvAttribs, 60000);
00063
00064
00065
00066 walkCPT(node.getCPF(), new int[bn.bn.getNodes().length], 0, instances, attrs);
00067
00068
00069 instances.setClass(attrs.get("prob"));
00070 J48 j48 = new J48();
00071 j48.setUnpruned(true);
00072 j48.setMinNumObj(0);
00073 j48.buildClassifier(instances);
00074
00075
00076 System.out.println(j48);
00077 }
00078 catch (Exception e) {
00079 e.printStackTrace();
00080 }
00081 }
00082
00083 public static void walkCPT(CPF cpf, int[] addr, int i, Instances instances, HashMap<String, Attribute> attrs) {
00084 BeliefNode[] nodes = cpf.getDomainProduct();
00085 if(i == addr.length) {
00086
00087 int realAddr = cpf.addr2realaddr(addr);
00088 double value = ((ValueDouble)cpf.get(realAddr)).getValue();
00089
00090
00091 Instance inst = new Instance(nodes.length+1);
00092
00093
00094 for(int j = 0; j < addr.length; j++) {
00095 Discrete dom = (Discrete)nodes[j].getDomain();
00096 inst.setValue(attrs.get(nodes[j].getName()), dom.getName(addr[j]));
00097 }
00098
00099 inst.setValue(attrs.get("prob"), Double.toString(value));
00100
00101
00102 instances.add(inst);
00103 }
00104 else {
00105 Discrete dom = (Discrete)nodes[i].getDomain();
00106 for(int j = 0; j < dom.getOrder(); j++) {
00107 addr[i] = j;
00108 walkCPT(cpf, addr, i+1, instances, attrs);
00109 }
00110 }
00111 }
00112
00113 public static void walkCPT2(CPF cpf, int[] addr, int i, Set<Double> values) {
00114 BeliefNode[] nodes = cpf.getDomainProduct();
00115 if(i == addr.length) {
00116
00117 int realAddr = cpf.addr2realaddr(addr);
00118 double value = ((ValueDouble)cpf.get(realAddr)).getValue();
00119 values.add(value);
00120 }
00121 else {
00122 Discrete dom = (Discrete)nodes[i].getDomain();
00123 for(int j = 0; j < dom.getOrder(); j++) {
00124 addr[i] = j;
00125 walkCPT2(cpf, addr, i+1, values);
00126 }
00127 }
00128 }
00129 }