00001 package edu.tum.cs.bayesnets.conversion;
00002
00003 import java.io.PrintStream;
00004 import java.util.HashMap;
00005 import java.util.HashSet;
00006 import java.util.Random;
00007 import java.util.Map.Entry;
00008
00009 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00010 import edu.ksu.cis.bnj.ver3.core.Domain;
00011 import edu.ksu.cis.util.graph.core.Graph;
00012 import edu.ksu.cis.util.graph.core.Vertex;
00013 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00014 import edu.tum.cs.bayesnets.learning.CPTLearner;
00015 import edu.tum.cs.bayesnets.learning.DomainLearner;
00016 import edu.tum.cs.srldb.Database;
00017 import edu.tum.cs.srldb.Object;
00018 import edu.tum.cs.srldb.datadict.DDAttribute;
00019 import edu.tum.cs.srldb.datadict.DDException;
00020 import edu.tum.cs.srldb.datadict.DDObject;
00021 import edu.tum.cs.srldb.datadict.DataDictionary;
00022 import edu.tum.cs.srldb.datadict.domain.AutomaticDomain;
00023 import edu.tum.cs.srldb.datadict.domain.BooleanDomain;
00024
00025 public class BN2SRLDB {
00026 protected BeliefNetworkEx bn;
00027 protected Database db;
00028 protected HashSet<String> booleanConversion;
00029 protected HashMap<String,String> undoConversion;
00030
00031 public BN2SRLDB(BeliefNetworkEx bn) {
00032 this.bn = bn;
00033 this.db = null;
00034 this.booleanConversion = null;
00035 }
00036
00037 public void setBooleanConversion(String attrName) {
00038 if(booleanConversion == null) {
00039 booleanConversion = new HashSet<String>();
00040 }
00041 booleanConversion.add(attrName);
00042 }
00043
00044 public Database getDB(int numSamples) throws DDException, Exception {
00045 return getDB(numSamples, new Random());
00046 }
00047
00048 protected boolean isBooleanNode(BeliefNode node) {
00049 Domain nodeDomain = node.getDomain();
00050 return nodeDomain.getOrder() == 2 && (nodeDomain.getName(0).equalsIgnoreCase("true") || nodeDomain.getName(0).equalsIgnoreCase("false"));
00051 }
00052
00053 public Database getDB(int numSamples, Random generator) throws DDException, Exception {
00054
00055 DataDictionary datadict = new DataDictionary();
00056 DDObject ddObj = new DDObject(Object.class.getSimpleName());
00057
00058 BeliefNode[] nodes = bn.bn.getNodes();
00059 for(int i = 0; i < nodes.length; i++) {
00060 edu.tum.cs.srldb.datadict.domain.Domain domain;
00061
00062 if(isBooleanNode(nodes[i])) {
00063 domain = BooleanDomain.getInstance();
00064 ddObj.addAttribute(new DDAttribute(nodes[i].getName(), domain));
00065 }
00066 else {
00067 String name = nodes[i].getName();
00068
00069 domain = new AutomaticDomain("dom" + nodes[i].getName());
00070 DDAttribute ddAttr = new DDAttribute(nodes[i].getName(), domain);
00071 ddObj.addAttribute(ddAttr);
00072
00073 if(booleanConversion != null && booleanConversion.contains(name)) {
00074
00075 Domain nodeDomain = nodes[i].getDomain();
00076 for(int j = 0; j < nodeDomain.getOrder(); j++) {
00077 ddObj.addAttribute(new DDAttribute(nodeDomain.getName(j), BooleanDomain.getInstance()));
00078 }
00079
00080
00081 ddAttr.discard();
00082 }
00083 }
00084 }
00085 datadict.addObject(ddObj);
00086
00087
00088 db = new Database(datadict);
00089
00090
00091 for(int i = 0; i < numSamples; i++) {
00092 HashMap<String,String> sample = bn.getSample(generator);
00093 Object obj = new Object(db, "object");
00094 if(booleanConversion != null) {
00095 for(String attrName : booleanConversion) {
00096 String value = sample.get(attrName);
00097 sample.put(value, "true");
00098
00099 }
00100 }
00101 obj.addAttributes(sample);
00102 obj.commit();
00103 System.out.println(sample);
00104 }
00105
00106 db.check();
00107
00108 return db;
00109 }
00110
00111 public void relearnBN() throws Exception {
00112 if(db == null)
00113 throw new Exception("No sampled data available for learning; call getDB() first!");
00114
00115 CPTLearner cptLearner = new CPTLearner(bn);
00116 for(Object obj : db.getObjects()) {
00117 cptLearner.learn(obj.getAttributes());
00118 }
00119 cptLearner.finish();
00120 }
00121
00122 protected void writeNodeLiteralAllCombs(PrintStream out, BeliefNode n, int varidx) {
00123 if(isBooleanNode(n))
00124 out.print("*" + Database.stdPredicateName(n.getName()) + "(o)");
00125 else
00126 out.print(Database.stdPredicateName(n.getName()) + "(o,+a" + varidx + ")");
00127 }
00128
00129 public void writeMLNFormulas(PrintStream out) {
00130 Graph g = bn.bn.getGraph();
00131 Vertex[] vertices = g.getVertices();
00132 BeliefNode[] nodes = bn.bn.getNodes();
00133 for(int i = 0; i < vertices.length; i++) {
00134 Vertex[] parents = g.getParents(vertices[i]);
00135 if(parents.length == 0)
00136 continue;
00137 int varidx = 0;
00138 for(int j = 0; j < parents.length; j++) {
00139 BeliefNode n = nodes[parents[j].loc()];
00140 if(j > 0)
00141 out.print(" ^ ");
00142 writeNodeLiteralAllCombs(out, n, varidx++);
00143 }
00144
00145 out.print(" => ");
00146 writeNodeLiteralAllCombs(out, nodes[vertices[i].loc()], varidx);
00147 out.println();
00148 }
00149 }
00150 }