00001 package edu.tum.cs.srl.bayesnets;
00002
00003 import java.io.PrintStream;
00004 import java.util.HashMap;
00005 import java.util.Map;
00006 import java.util.Set;
00007 import java.util.TreeSet;
00008 import java.util.Vector;
00009
00010 import weka.classifiers.trees.J48;
00011 import weka.classifiers.trees.j48.Rule;
00012 import weka.classifiers.trees.j48.Rule.Condition;
00013 import weka.core.Attribute;
00014 import weka.core.FastVector;
00015 import weka.core.Instance;
00016 import weka.core.Instances;
00017 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00018 import edu.ksu.cis.bnj.ver3.core.CPF;
00019 import edu.ksu.cis.bnj.ver3.core.Discrete;
00020 import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
00021 import edu.tum.cs.logic.Conjunction;
00022 import edu.tum.cs.logic.Formula;
00023
00024 public class CPT2MLNFormulas {
00025 protected RelationalBeliefNetwork rbn;
00026 protected RelationalNode mainNode;
00027 protected String additionalPrecondition;
00028
/**
 * Creates a converter for the conditional probability table of the given node.
 * Initially, no additional precondition is set.
 *
 * @param node the relational node to convert; its network is retrieved via
 *             {@code node.getNetwork()}
 */
public CPT2MLNFormulas(RelationalNode node) {
    this.rbn = node.getNetwork();
    this.mainNode = node;
    this.additionalPrecondition = null;
}
00034
00039 public void addPrecondition(String cond) {
00040 if(additionalPrecondition == null)
00041 additionalPrecondition = cond;
00042 else
00043 additionalPrecondition += " ^ " + cond;
00044 }
00045
00046 public String getPrecondition() {
00047 return additionalPrecondition;
00048 }
00049
00054 public void convert(PrintStream out) {
00055 try {
00056 CPT2Rules cpt2rules = new CPT2Rules(mainNode);
00057
00058
00059
00060 for(HashMap<String, String> constantAssignment : mainNode.getConstantAssignments()) {
00061 Rule[] rules = cpt2rules.learnRules(constantAssignment);
00062
00063
00064 for(Rule rule : rules) {
00065 if(!rule.hasAntds())
00066 continue;
00067 StringBuffer conjunction = new StringBuffer();
00068 int lits = 0;
00069
00070
00071 boolean haveMainNode = false;
00072 for(Condition c : rule.getAntecedent()) {
00073 RelationalNode node = cpt2rules.getRelationalNode(c);
00074 if(node == mainNode)
00075 haveMainNode = true;
00076 String literal = node.toLiteralString(rbn.getDomainIndex(node.node, c.getValue()), constantAssignment);
00077 if(lits++ > 0)
00078 conjunction.append(" ^ ");
00079 conjunction.append(literal);
00080 }
00081
00082
00083 for(RelationalNode parent : rbn.getRelationalParents(mainNode)) {
00084 if(parent.isPrecondition) {
00085 if(lits++ > 0)
00086 conjunction.append(" ^ ");
00087 conjunction.append(parent.toLiteralString(rbn.getDomainIndex(parent.node, "True"), constantAssignment));
00088 }
00089 }
00090
00091 Vector<String> conjunctions = new Vector<String>();
00092 if(!haveMainNode) {
00093 for(int i = 0; i < mainNode.node.getDomain().getOrder(); i++) {
00094 conjunctions.add(conjunction.toString() + " ^ " + mainNode.toLiteralString(i, null));
00095 }
00096 }
00097 else
00098 conjunctions.add(conjunction.toString());
00099
00100 double prob = Double.parseDouble(rule.getConsequent().getValue());
00101 double weight = prob == 0.0 ? -100 : Math.log(prob);
00102 for(String conj : conjunctions) {
00103 out.print(weight + " ");
00104 out.print(conj);
00105 if(additionalPrecondition != null)
00106 out.print(" ^ " + additionalPrecondition);
00107 out.println();
00108 }
00109 }
00110 }
00111 }
00112 catch (Exception e) {
00113 e.printStackTrace();
00114 }
00115 }
00116
00123 public static class CPT2Rules {
00127 protected HashMap<String, Attribute> attrs;
00128 protected RelationalBeliefNetwork rbn;
00129 protected CPF cpf;
00130 BeliefNode[] nodes;
00131 FastVector fvAttribs;
00132 HashMap<Attribute, RelationalNode> relNodes;
00133 RelationalNode mainNode;
00134 int zerosInCPT;
00135
00136 public CPT2Rules(RelationalNode relNode) {
00137 mainNode = relNode;
00138 rbn = relNode.getNetwork();
00139 cpf = relNode.node.getCPF();
00140 nodes = cpf.getDomainProduct();
00141
00142
00143 fvAttribs = new FastVector(nodes.length+1);
00144 attrs = new HashMap<String,Attribute>();
00145
00146
00147 relNodes = new HashMap<Attribute, RelationalNode>();
00148 for(BeliefNode node : nodes) {
00149 ExtendedNode extNode = rbn.getExtendedNode(node);
00150 if(extNode instanceof DecisionNode)
00151 continue;
00152 Discrete dom = (Discrete)node.getDomain();
00153 FastVector attValues = new FastVector(dom.getOrder());
00154 for(int i = 0; i < dom.getOrder(); i++)
00155 attValues.addElement(dom.getName(i));
00156 Attribute attr = new Attribute(node.getName(), attValues);
00157 attrs.put(node.getName(), attr);
00158 relNodes.put(attr, rbn.getRelationalNode(node));
00159 fvAttribs.addElement(attr);
00160 }
00161
00162
00163
00164 TreeSet<Double> probs = new TreeSet<Double>();
00165 zerosInCPT = 0;
00166 walkCPT4ValueSet(new int[nodes.length], 0, probs);
00167 FastVector attrValues = new FastVector(probs.size());
00168 for(Double d : probs)
00169 attrValues.addElement(Double.toString(d));
00170
00171 Attribute probAttr = new Attribute("prob", attrValues);
00172 attrs.put("prob", probAttr);
00173 fvAttribs.addElement(probAttr);
00174 }
00175
00176 protected void walkCPT4ValueSet(int[] addr, int i, Set<Double> values) {
00177 BeliefNode[] nodes = cpf.getDomainProduct();
00178 if(i == addr.length) {
00179
00180 int realAddr = cpf.addr2realaddr(addr);
00181 double value = ((ValueDouble)cpf.get(realAddr)).getValue();
00182 if(value == 0.0)
00183 zerosInCPT++;
00184 values.add(value);
00185 }
00186 else {
00187 Discrete dom = (Discrete)nodes[i].getDomain();
00188 ExtendedNode extNode = rbn.getExtendedNode(nodes[i]);
00189 if(extNode instanceof DecisionNode) {
00190 addr[i] = 0;
00191 walkCPT4ValueSet(addr, i+1, values);
00192 }
00193 else {
00194 RelationalNode n = (RelationalNode)extNode;
00195 if(n.isPrecondition) {
00196 addr[i] = dom.findName("True");
00197 walkCPT4ValueSet(addr, i+1, values);
00198 }
00199 else {
00200 for(int j = 0; j < dom.getOrder(); j++) {
00201 addr[i] = j;
00202 walkCPT4ValueSet(addr, i+1, values);
00203 }
00204 }
00205 }
00206 }
00207 }
00208
00209 public int getZerosInCPT() {
00210 return zerosInCPT;
00211 }
00212
00219 public Rule[] learnRules(Map<String, String> constantAssignment) throws Exception {
00220
00221 Instances instances = new Instances("foo", fvAttribs, 60000);
00222 walkCPT4InstanceCollection(new int[nodes.length], 0, constantAssignment, instances);
00223
00224
00225 instances.setClass(attrs.get("prob"));
00226 J48 j48 = new J48();
00227 j48.setUnpruned(true);
00228 j48.setMinNumObj(0);
00229 j48.buildClassifier(instances);
00230
00231
00232
00233
00234 return j48.getRules();
00235 }
00236
00237 protected void walkCPT4InstanceCollection(int[] addr, int i, Map<String,String> constantSettings, Instances instances) throws Exception {
00238 BeliefNode[] nodes = cpf.getDomainProduct();
00239 if(i == addr.length) {
00240
00241 int realAddr = cpf.addr2realaddr(addr);
00242 double value = ((ValueDouble)cpf.get(realAddr)).getValue();
00243
00244
00245 Instance inst = new Instance(nodes.length+1);
00246
00247
00248 for(int j = 0; j < addr.length; j++) {
00249 Attribute attr = attrs.get(nodes[j].getName());
00250 if(attr != null) {
00251 Discrete dom = (Discrete)nodes[j].getDomain();
00252 inst.setValue(attr, dom.getName(addr[j]));
00253 }
00254 }
00255
00256 inst.setValue(attrs.get("prob"), Double.toString(value));
00257
00258
00259 instances.add(inst);
00260 }
00261 else {
00262 Discrete dom = (Discrete)nodes[i].getDomain();
00263 ExtendedNode extNode = rbn.getExtendedNode(nodes[i]);
00264 if(extNode instanceof DecisionNode) {
00265 addr[i] = 0;
00266 walkCPT4InstanceCollection(addr, i+1, constantSettings, instances);
00267 }
00268 else {
00269 RelationalNode n = (RelationalNode)extNode;
00270 if(n.isPrecondition) {
00271 addr[i] = dom.findName("True");
00272 if(addr[i] == -1)
00273 throw new Exception("The node " + nodes[i] + " is set as a precondition, but its domain does not contain the value 'True'.");
00274 walkCPT4InstanceCollection(addr, i+1, constantSettings, instances);
00275 }
00276 else if(n.isConstant) {
00277 addr[i] = dom.findName(constantSettings.get(n.getName()));
00278 walkCPT4InstanceCollection(addr, i+1, constantSettings, instances);
00279 }
00280 else {
00281 for(int j = 0; j < dom.getOrder(); j++) {
00282 addr[i] = j;
00283 walkCPT4InstanceCollection(addr, i+1, constantSettings, instances);
00284 }
00285 }
00286 }
00287 }
00288 }
00289
00295 public RelationalNode getRelationalNode(Condition c) {
00296 return relNodes.get(c.getAttribute());
00297 }
00298
00299 public Formula getConjunction(Rule rule, Map<String,String> constantAssignment) throws Exception {
00300 boolean haveMainNode = false;
00301 Vector<Formula> conjuncts = new Vector<Formula>();
00302 for(Condition c : rule.getAntecedent()) {
00303 RelationalNode node = this.getRelationalNode(c);
00304 if(node == mainNode)
00305 haveMainNode = true;
00306 int value = rbn.getDomainIndex(node.node, c.getValue());
00307 Formula literal = node.toLiteral(value, constantAssignment);
00308 conjuncts.add(literal);
00309 }
00310 return new Conjunction(conjuncts);
00311 }
00312
00313 public double getProbability(Rule r) {
00314 return Double.parseDouble(r.getConsequent().getValue());
00315 }
00316 }
00317 }