00001 package edu.tum.cs.srl.bayesnets.learning;
00002
00003 import java.util.HashSet;
00004
00005 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00006 import edu.ksu.cis.bnj.ver3.core.Discrete;
00007 import edu.tum.cs.srl.Database;
00008 import edu.tum.cs.srl.Signature;
00009 import edu.tum.cs.srl.bayesnets.DecisionNode;
00010 import edu.tum.cs.srl.bayesnets.ExtendedNode;
00011 import edu.tum.cs.srl.bayesnets.RelationalBeliefNetwork;
00012 import edu.tum.cs.srl.bayesnets.RelationalNode;
00013
00014 public class DomainLearner extends edu.tum.cs.bayesnets.learning.DomainLearner {
00015 public DomainLearner(RelationalBeliefNetwork bn) {
00016 super(bn);
00017 }
00018
00019 public void learn(Database db) throws Exception {
00020
00021 boolean debug = false;
00022 RelationalBeliefNetwork bn = (RelationalBeliefNetwork)this.bn;
00023 BeliefNode[] nodes = bn.bn.getNodes();
00024 for(int i = 0; i < nodes.length; i++) {
00025 ExtendedNode extNode = bn.getExtendedNode(i);
00026 boolean mustApplyBooleanDomain = false;
00027
00028 if(extNode instanceof DecisionNode)
00029 mustApplyBooleanDomain = true;
00030
00031 if(extNode instanceof RelationalNode) {
00032 RelationalNode node = (RelationalNode)extNode;
00033
00034 mustApplyBooleanDomain = node.isBuiltInPred();
00035 if(!mustApplyBooleanDomain) {
00036 if(debug) System.out.println("node: " + node);
00037 Signature sig = bn.getSignature(node.getFunctionName());
00038 if(sig == null) {
00039 throw new Exception("Could not obtain signature of " + node.getFunctionName());
00040 }
00041 if(sig.isBoolean())
00042 mustApplyBooleanDomain = true;
00043 else {
00044 Iterable<String> values = db.getDomain(sig.returnType);
00045 if(values == null) {
00046 db.printDomain(System.out);
00047 throw new Exception("Domain '" + sig.returnType + "' of node '" + nodes[i].getName() + "' has no values in the database.");
00048 }
00049 for(String value : values) {
00050 if(debug) System.out.println("adding " + value + " to " + sig.returnType + " while processing " + sig.functionName + " - returnType = " + sig.returnType);
00051 ((HashSet<String>)directDomainData[i]).add(value);
00052 }
00053 continue;
00054 }
00055 }
00056 }
00057 if(mustApplyBooleanDomain) {
00058 ((HashSet<String>)directDomainData[i]).add("True");
00059 ((HashSet<String>)directDomainData[i]).add("False");
00060 continue;
00061 }
00062 }
00063 }
00064
00065 protected void end_learning() throws Exception {
00066 super.end_learning();
00067
00068
00069 Discrete booleanDomain = new Discrete(new String[]{"True", "False"});
00070 RelationalBeliefNetwork bn = (RelationalBeliefNetwork)this.bn;
00071 BeliefNode[] nodes = bn.bn.getNodes();
00072 for(int i = 0; i < nodes.length; i++) {
00073 System.out.println(nodes[i].getName());
00074 if(RelationalBeliefNetwork.isBooleanDomain((Discrete)nodes[i].getDomain())) {
00075 ExtendedNode extNode = bn.getExtendedNode(i);
00076 if(extNode instanceof RelationalNode) {
00077 Signature sig = bn.getSignature((RelationalNode)extNode);
00078 if(sig != null)
00079 sig.returnType = "Boolean";
00080 }
00081 bn.bn.changeBeliefNodeDomain(nodes[i], booleanDomain);
00082 }
00083 }
00084 }
00085 }