00001 package edu.tum.cs.srl.bayesnets;
00002
00003 import java.io.FileNotFoundException;
00004 import java.io.IOException;
00005 import java.io.PrintStream;
00006 import java.util.Collection;
00007 import java.util.HashSet;
00008 import java.util.Set;
00009 import java.util.Vector;
00010 import java.util.regex.Matcher;
00011 import java.util.regex.Pattern;
00012
00013 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00014 import edu.ksu.cis.bnj.ver3.core.CPF;
00015 import edu.ksu.cis.bnj.ver3.core.Discrete;
00016 import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
00017 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00018 import edu.tum.cs.srl.RelationKey;
00019 import edu.tum.cs.srl.Signature;
00020 import edu.tum.cs.util.FileUtil;
00021 import edu.tum.cs.util.StringTool;
00022
00023 public class BLOGModel extends RelationalBeliefNetwork {
00024
00034 public BLOGModel(String[] blogFiles, String networkFile) throws Exception {
00035 super(networkFile);
00036
00037
00038 String blog = readBlogContent(blogFiles);
00039
00040
00041 Pattern comments = Pattern.compile("//.*?$|/\\*.*?\\*/",
00042 Pattern.MULTILINE | Pattern.DOTALL);
00043 Matcher matcher = comments.matcher(blog);
00044 blog = matcher.replaceAll("");
00045
00046
00047 String[] lines = blog.split("\n");
00048 for (String line : lines) {
00049 line = line.trim();
00050 if (line.length() == 0)
00051 continue;
00052 if (!readDeclaration(line))
00053 if (!line.contains("~"))
00054 throw new Exception("Could not interpret the line '" + line
00055 + "'");
00056 }
00057
00058 checkSignatures();
00059 }
00060
00061 protected boolean readDeclaration(String line) throws Exception {
00062
00063
00064
00065 if (line.startsWith("random") || line.startsWith("logical")) {
00066 Pattern pat = Pattern.compile(
00067 "(random|logical)\\s+(\\w+)\\s+(\\w+)\\s*\\((.*)\\)\\s*;?",
00068 Pattern.CASE_INSENSITIVE);
00069 Matcher matcher = pat.matcher(line);
00070 if (matcher.matches()) {
00071 boolean isLogical = matcher.group(1).equals("logical");
00072 String retType = matcher.group(2);
00073 String[] argTypes = matcher.group(4).trim().split("\\s*,\\s*");
00074 Signature sig = new Signature(matcher.group(3), retType,
00075 argTypes, isLogical);
00076 addSignature(matcher.group(3), sig);
00077 return true;
00078 }
00079 return false;
00080 }
00081
00082 if (line.startsWith("guaranteed")) {
00083 Pattern pat = Pattern
00084 .compile("guaranteed\\s+(\\w+)\\s+(.*?)\\s*;?");
00085 Matcher matcher = pat.matcher(line);
00086 if (matcher.matches()) {
00087 String domName = matcher.group(1);
00088 String[] elems = matcher.group(2).split("\\s*,\\s*");
00089 guaranteedDomElements.put(domName, elems);
00090 return true;
00091 }
00092 return false;
00093 }
00094 return false;
00095 }
00096
00105 protected String readBlogContent(String[] files)
00106 throws FileNotFoundException, IOException {
00107
00108 StringBuffer buf = new StringBuffer();
00109 for (String blogFile : files) {
00110 buf.append(FileUtil.readTextFile(blogFile));
00111 buf.append('\n');
00112 }
00113 return buf.toString();
00114 }
00115
/**
 * Creates a model from a single BLOG file and a network file by delegating
 * to the constructor that takes an array of BLOG files.
 *
 * @param blogFile    name of the BLOG file to read
 * @param networkFile name of the network file
 * @throws Exception if the delegated constructor fails
 */
public BLOGModel(String blogFile, String networkFile) throws Exception {
    this(new String[] {blogFile}, networkFile);
}
00127
/**
 * Creates a model from a network file alone; no BLOG file is read.
 * After the superclass has been initialized with the file name,
 * {@code guessSignatures()} is called to obtain the signatures.
 *
 * @param xmlbifFile name of the network file passed to the superclass
 * @throws Exception if the superclass constructor or
 *                   {@code guessSignatures()} fails
 */
public BLOGModel(String xmlbifFile) throws Exception {
    super(xmlbifFile);
    guessSignatures();
}
00139
00147 public BeliefNetworkEx getGroundBN() throws Exception {
00148
00149 BeliefNetworkEx gbn = new BeliefNetworkEx();
00150
00151 int[] order = this.getTopologicalOrder();
00152 for (int i = 0; i < order.length; i++) {
00153
00154 RelationalNode node = getRelationalNode(order[i]);
00155
00156 Signature sig = getSignature(node.functionName);
00157 if (sig == null)
00158 throw new Exception("Could not retrieve signature for node "
00159 + node.functionName);
00160 Vector<String[]> argGroundings = groundParams(sig);
00161
00162
00163 for (String[] args : argGroundings) {
00164 String newName = Signature.formatVarName(node.functionName,
00165 args);
00166 BeliefNode newNode = new BeliefNode(newName, node.node
00167 .getDomain());
00168 gbn.addNode(newNode);
00169
00170 String[] parentNames = getParentVariableNames(node, args);
00171 for (String parentName : parentNames) {
00172 BeliefNode parent = gbn.getNode(parentName);
00173 gbn.bn.connect(parent, newNode);
00174 }
00175
00176
00177
00178
00179 CPF newCPF = newNode.getCPF(), oldCPF = node.node.getCPF();
00180 BeliefNode[] oldProd = oldCPF.getDomainProduct();
00181 BeliefNode[] newProd = newCPF.getDomainProduct();
00182 int[] old2newindex = new int[oldProd.length];
00183 for (int j = 0; j < oldProd.length; j++) {
00184 for (int k = 0; k < newProd.length; k++)
00185 if (RelationalNode.extractFunctionName(
00186 newProd[k].getName()).equals(
00187 RelationalNode.extractFunctionName(oldProd[j]
00188 .getName())))
00189 old2newindex[j] = k;
00190 }
00191 for (int j = 0; j < oldCPF.size(); j++) {
00192 int[] oldAddr = oldCPF.realaddr2addr(j);
00193 int[] newAddr = new int[oldAddr.length];
00194 for (int k = 0; k < oldAddr.length; k++)
00195 newAddr[old2newindex[k]] = oldAddr[k];
00196 newCPF.put(newCPF.addr2realaddr(newAddr), oldCPF.get(j));
00197 }
00198 }
00199 }
00200 return gbn;
00201 }
00202
00218 protected void groundParams(String[] domNames, String[] setting, int idx,
00219 Vector<String[]> ret) throws Exception {
00220 if (idx == domNames.length) {
00221 ret.add(setting.clone());
00222 return;
00223 }
00224 String[] elems = guaranteedDomElements.get(domNames[idx]);
00225 if (elems == null) {
00226 throw new Exception("No guaranteed domain elements for "
00227 + domNames[idx]);
00228 }
00229 for (String elem : elems) {
00230 setting[idx] = elem;
00231 groundParams(domNames, setting, idx + 1, ret);
00232 }
00233 }
00234
00235 protected Vector<String[]> groundParams(Signature sig) throws Exception {
00236 Vector<String[]> ret = new Vector<String[]>();
00237 groundParams(sig.argTypes, new String[sig.argTypes.length], 0, ret);
00238 return ret;
00239 }
00240
00241 public void write(PrintStream out) throws Exception {
00242 BeliefNode[] nodes = bn.getNodes();
00243
00244
00245
00246 writeDeclarations(out);
00247
00248
00249
00250 for (RelationalNode relNode : getRelationalNodes()) {
00251 if (relNode.isAuxiliary)
00252 continue;
00253 CPF cpf = nodes[relNode.index].getCPF();
00254 BeliefNode[] deps = cpf.getDomainProduct();
00255 Discrete[] domains = new Discrete[deps.length];
00256 StringBuffer args = new StringBuffer();
00257 int[] addr = new int[deps.length];
00258 for (int j = 0; j < deps.length; j++) {
00259 if (deps[j].getType() == BeliefNode.NODE_DECISION)
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270 continue;
00271 if (j > 0) {
00272 if (j > 1)
00273 args.append(", ");
00274 args.append(getRelationalNode(deps[j]).getCleanName());
00275 }
00276 domains[j] = (Discrete) deps[j].getDomain();
00277 }
00278 Vector<String> lists = new Vector<String>();
00279 getCPD(lists, cpf, domains, addr, 1);
00280 out.printf("%s ~ TabularCPD[%s](%s);\n", relNode.getCleanName(),
00281 StringTool.join(",", lists.toArray(new String[0])), args
00282 .toString());
00283 }
00284 }
00285
00286 protected void writeDeclarations(PrintStream out) {
00287
00288 Set<String> types = new HashSet<String>();
00289 for (RelationalNode node : this.getRelationalNodes()) {
00290 if (node.isBuiltInPred())
00291 continue;
00292 Signature sig = this.getSignature(node.functionName);
00293 Discrete domain = (Discrete) node.node.getDomain();
00294 if (!types.contains(sig.returnType)
00295 && !sig.returnType.equals("Boolean")) {
00296 if (!isBooleanDomain(domain)) {
00297 types.add(sig.returnType);
00298 out.printf("Type %s;\n", sig.returnType);
00299 } else
00300 sig.returnType = "Boolean";
00301 }
00302 for (String t : sig.argTypes) {
00303 if (!types.contains(t)) {
00304 types.add(t);
00305 out.printf("Type %s;\n", t);
00306 }
00307 }
00308 }
00309 out.println();
00310
00311
00312 Set<String> handledDomains = new HashSet<String>();
00313 for (RelationalNode node : this.getRelationalNodes()) {
00314 if (node.isBuiltInPred())
00315 continue;
00316 Discrete domain = (Discrete) node.node.getDomain();
00317 Signature sig = getSignature(node.functionName);
00318 if (!sig.returnType.equals("Boolean")) {
00319 String t = sig.returnType;
00320 if (!handledDomains.contains(t)) {
00321 handledDomains.add(t);
00322 out.print("guaranteed " + t + " ");
00323 for (int j = 0; j < domain.getOrder(); j++) {
00324 if (j > 0)
00325 out.print(", ");
00326 out.print(domain.getName(j));
00327 }
00328 out.println(";");
00329 }
00330 }
00331 }
00332 out.println();
00333
00334
00335 for (RelationalNode node : this.getRelationalNodes()) {
00336 if (node.isBuiltInPred())
00337 continue;
00338 Signature sig = getSignature(node.functionName);
00339 out.printf("random %s %s(%s);\n", sig.returnType,
00340 node.functionName, StringTool.join(", ", sig.argTypes));
00341 }
00342 out.println();
00343
00344
00345 for(Collection<RelationKey> c : this.relationKeys.values())
00346 for(RelationKey relKey : c)
00347 out.println(relKey.toString());
00348 out.println();
00349 }
00350
00351 protected void getCPD(Vector<String> lists, CPF cpf, Discrete[] domains,
00352 int[] addr, int i) {
00353 if (i == addr.length) {
00354 StringBuffer sb = new StringBuffer();
00355 sb.append('[');
00356 for (int j = 0; j < domains[0].getOrder(); j++) {
00357 addr[0] = j;
00358 int realAddr = cpf.addr2realaddr(addr);
00359 double value = ((ValueDouble) cpf.get(realAddr)).getValue();
00360 if (j > 0)
00361 sb.append(',');
00362 sb.append(value);
00363 }
00364 sb.append(']');
00365 lists.add(sb.toString());
00366 } else {
00367
00368 BeliefNode[] domProd = cpf.getDomainProduct();
00369 if (domProd[i].getType() == BeliefNode.NODE_DECISION)
00370
00371
00372
00373
00374
00375 addr[i] = 0;
00376 else {
00377 for (int j = 0; j < domains[i].getOrder(); j++) {
00378 addr[i] = j;
00379 getCPD(lists, cpf, domains, addr, i + 1);
00380 }
00381 }
00382 }
00383 }
00384 }