00001 package edu.tum.cs.srl.bayesnets.learning;
00002
00003 import java.util.Collection;
00004 import java.util.HashMap;
00005 import java.util.Map;
00006 import java.util.Vector;
00007 import java.util.Map.Entry;
00008
00009 import edu.ksu.cis.bnj.ver3.core.Discrete;
00010 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00011 import edu.tum.cs.srl.Database;
00012 import edu.tum.cs.srl.Signature;
00013 import edu.tum.cs.srl.bayesnets.DecisionNode;
00014 import edu.tum.cs.srl.bayesnets.ExtendedNode;
00015 import edu.tum.cs.srl.bayesnets.ParentGrounder;
00016 import edu.tum.cs.srl.bayesnets.RelationalBeliefNetwork;
00017 import edu.tum.cs.srl.bayesnets.RelationalNode;
00018 import edu.tum.cs.srl.bayesnets.RelationalNode.Aggregator;
00019 import edu.tum.cs.util.StringTool;
00020
00021 public class CPTLearner extends edu.tum.cs.bayesnets.learning.CPTLearner {
00022
00023 protected HashMap<Integer, HashMap<String, Integer>> marginals;
00024 protected int numExamples;
00025 protected boolean verbose;
00026 protected boolean debug = false;
00027
00028 public CPTLearner(RelationalBeliefNetwork bn) throws Exception {
00029 this(bn, false, false);
00030 }
00031
00032 public CPTLearner(RelationalBeliefNetwork bn, boolean uniformDefault, boolean debug) throws Exception {
00033 super(bn);
00034 setUniformDefault(uniformDefault);
00035 this.debug = debug;
00036
00037 }
00038
00047 protected void countVariable(Database db, RelationalNode node, String[] params, boolean closedWorld) throws Exception {
00048
00049 if(!node.hasCPT())
00050 return;
00051 RelationalBeliefNetwork bn = (RelationalBeliefNetwork)this.bn;
00052
00053 ExampleCounter counter = this.counters[node.index];
00054
00055 String varName = Signature.formatVarName(node.getFunctionName(), params);
00056
00057
00058
00059 ParentGrounder pg = bn.getParentGrounder(node);
00060 Vector<Map<Integer, String[]>> groundings = pg.getGroundings(params, db);
00061 if(groundings == null) {
00062 if(debug)
00063 System.err.println("Variable " + Signature.formatVarName(node.getFunctionName(), params)+ " skipped because parents could not be grounded.");
00064 return;
00065 }
00066
00067
00068
00069
00070
00071
00072
00073
00074 double exampleWeight = 1.0;
00075
00076
00077 if(false) {
00078
00079
00080 if(node.aggregator == Aggregator.Average && node.parentMode != null && node.parentMode.equals("CP")) {
00081
00082
00083 int dim = 1;
00084 Vector<Integer> relevantParentIndices = new Vector<Integer>();
00085 Vector<Integer> precondParentIndices = new Vector<Integer>();
00086 for(RelationalNode parent : bn.getRelationalParents(node)) {
00087 if(parent.isPrecondition) {
00088 precondParentIndices.add(parent.index);
00089 continue;
00090 }
00091 dim *= parent.getDomain().getOrder();
00092 relevantParentIndices.add(parent.index);
00093 }
00094 double[] v = new double[dim];
00095
00096 int numExamples = 0;
00097 for(Map<Integer, String[]> paramSets : groundings) {
00098 boolean skip = false;
00099
00100 for(Integer nodeIdx : precondParentIndices) {
00101 RelationalNode ndCurrent = bn.getRelationalNode(nodeIdx);
00102 String value = db.getVariableValue(ndCurrent.getVariableName(paramSets.get(ndCurrent.index)), closedWorld);
00103 if(!value.equalsIgnoreCase("true")) {
00104 skip = true;
00105 break;
00106 }
00107 }
00108 if(skip)
00109 continue;
00110
00111 int factor = 1;
00112 int addr = 0;
00113 for(Integer nodeIdx : relevantParentIndices) {
00114 RelationalNode ndCurrent = bn.getRelationalNode(nodeIdx);
00115
00116 String value = ndCurrent.getValueInDB(paramSets.get(ndCurrent.index), db, closedWorld);
00117 Discrete dom = ndCurrent.getDomain();
00118 int domIdx = dom.findName(value);
00119 if(domIdx < 0) {
00120 String[] domain = BeliefNetworkEx.getDiscreteDomainAsArray(ndCurrent.node);
00121 throw new Exception("Could not find value '" + value + "' in domain of " + ndCurrent.toString() + " {" + StringTool.join(",", domain) + "}");
00122 }
00123 addr += factor * domIdx;
00124 factor *= dom.getOrder();
00125 }
00126 v[addr] += 1;
00127 numExamples++;
00128 }
00129
00130 for(int i = 0; i < v.length; i++)
00131 v[i] = v[i] / numExamples;
00132
00133 exampleWeight = 0;
00134 int exponent = 10;
00135 for(int i = 0; i < v.length; i++) {
00136 exampleWeight += Math.pow(v[i], exponent);
00137 }
00138
00139 }
00140 }
00141
00142
00143
00144 for(Map<Integer, String[]> paramSets : groundings) {
00145
00146
00147 boolean countExample = true;
00148
00149 for(int i = 1; i < counter.nodeIndices.length; i++) {
00150 ExtendedNode extCurrent = bn.getExtendedNode(counter.nodeIndices[i]);
00151 if(!(extCurrent instanceof RelationalNode))
00152 continue;
00153 RelationalNode ndCurrent = (RelationalNode)extCurrent;
00154 if(ndCurrent.isPrecondition) {
00155 String[] actualParams = paramSets.get(ndCurrent.index);
00156 String value = ndCurrent.getValueInDB(actualParams, db, closedWorld);
00157
00158 if(!value.equalsIgnoreCase("true")) {
00159 countExample = false;
00160 break;
00161 }
00162 }
00163 }
00164
00165 if(!countExample)
00166 continue;
00167
00168
00169 int domainIndices[] = new int[this.nodes.length];
00170 for(int i = 0; i < counter.nodeIndices.length; i++) {
00171 int domain_idx = -1;
00172 ExtendedNode extCurrent = bn.getExtendedNode(counter.nodeIndices[i]);
00173
00174
00175 if(extCurrent instanceof DecisionNode) {
00176 domain_idx = 0;
00177 }
00178
00179 else {
00180
00181 RelationalNode ndCurrent = (RelationalNode)extCurrent;
00182
00183 if(ndCurrent.isPrecondition) {
00184 domainIndices[extCurrent.index] = 0;
00185 continue;
00186 }
00187
00188 String[] actualParams = paramSets.get(ndCurrent.index);
00189 if(actualParams == null) {
00190 Vector<String> availableNodes = new Vector<String>();
00191 for(Integer idx : paramSets.keySet())
00192 availableNodes.add(idx.toString() + "/" + ndCurrent.getNetwork().getRelationalNode(idx).toString());
00193 throw new Exception("Relevant node " + ndCurrent.index + "/" + ndCurrent + " has no grounding for main node instantiation " + varName + "; have only " + availableNodes.toString());
00194 }
00195 String value = ndCurrent.getValueInDB(actualParams, db, closedWorld);
00196 if(value == null)
00197 throw new Exception(String.format("Could not find setting for node named '%s' while processing '%s'", ndCurrent.getName(), varName));
00198
00199 Discrete dom = (Discrete)(ndCurrent.node.getDomain());
00200 domain_idx = dom.findName(value);
00201 if(domain_idx == -1) {
00202 String[] domElems = new String[dom.getOrder()];
00203 for(int j = 0; j < domElems.length; j++)
00204 domElems[j] = dom.getName(j);
00205 throw new Exception(String.format("'%s' not found in domain of %s {%s} while processing %s", value, ndCurrent.getFunctionName(), StringTool.join(",", domElems), varName));
00206 }
00207
00208 if(ndCurrent.isConstant) {
00209 int[] constantDomainIndices = new int[this.nodes.length];
00210 constantDomainIndices[ndCurrent.index] = domain_idx;
00211 this.counters[ndCurrent.index].count(constantDomainIndices);
00212 }
00213 }
00214 domainIndices[extCurrent.index] = domain_idx;
00215 }
00216
00217
00218 counter.count(domainIndices, exampleWeight);
00219 numExamples++;
00220 if(debug && verbose) {
00221 StringBuffer condition = new StringBuffer();
00222 for(Entry<Integer, String[]> e : paramSets.entrySet()) {
00223 if(e.getKey() == node.index)
00224 continue;
00225 RelationalNode rn = bn.getRelationalNode(e.getKey());
00226 condition.append(' ');
00227 condition.append(rn.getVariableName(e.getValue()));
00228 condition.append('=');
00229 condition.append(rn.getDomain().getName(domainIndices[rn.index]));
00230 }
00231 System.out.println(" " + node.getVariableName(params) + "=" + node.getDomain().getName(domainIndices[node.index]) + " |" + condition);
00232 }
00233
00234
00235
00236
00237
00238
00239
00240 }
00241 }
00242
00248 @Deprecated
00249 public void learn(Database db) throws Exception {
00250 throw new Exception("No longer supported");
00251
00252
00253
00254 }
00255
00263 public void learnTyped(Database db, boolean closedWorld, boolean verbose) throws Exception {
00264 this.verbose = verbose;
00265 RelationalBeliefNetwork bn = (RelationalBeliefNetwork)this.bn;
00266
00267
00268
00269 for(RelationalNode node : bn.getRelationalNodes()) {
00270 if(node.isConstant || node.isBuiltInPred() || !node.hasCPT())
00271 continue;
00272 node.getParentGrounder();
00273 }
00274
00275
00276 for(RelationalNode node : bn.getRelationalNodes()) {
00277 if(node.isConstant || node.isBuiltInPred())
00278 continue;
00279 numExamples = 0;
00280 if(verbose)
00281 System.out.println(" " + node.getName());
00282
00283 String[] params = new String[node.params.length];
00284 countVariable(db, node, params, bn.getSignature(node.getFunctionName()).argTypes, 0, closedWorld);
00285 if(verbose)
00286 System.out.println(" " + numExamples + " counted");
00287
00288 }
00289 }
00290
00301 protected void countVariable(Database db, RelationalNode node, String[] params, String[] domainNames, int i, boolean closedWorld) throws Exception {
00302
00303 if(i == params.length) {
00304
00305 if(!closedWorld) {
00306 String varName = Signature.formatVarName(node.getFunctionName(), params);
00307 if(!db.contains(varName))
00308 throw new Exception("Incomplete data: No value for " + varName);
00309 }
00310
00311
00312
00313
00314 Collection<DecisionNode> decisions = node.getDecisionParents();
00315 if(decisions.size() > 0) {
00316 for(DecisionNode decision : decisions) {
00317 if(!decision.isTrue(node.params, params, db, closedWorld))
00318 return;
00319 }
00320 }
00321
00322 countVariable(db, node, params, closedWorld);
00323 return;
00324 }
00325
00326
00327 if(RelationalNode.isConstant(node.params[i])) {
00328 params[i] = node.params[i];
00329 countVariable(db, node, params, domainNames, i+1, closedWorld);
00330 }
00331 else {
00332 Iterable<String> domain = db.getDomain(domainNames[i]);
00333 for(String element : domain) {
00334 params[i] = element;
00335 countVariable(db, node, params, domainNames, i+1, closedWorld);
00336 }
00337 }
00338 }
00339 }