00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.util.Arrays;
00004 import java.util.HashMap;
00005 import java.util.Map;
00006 import java.util.Random;
00007
00008 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00009 import edu.ksu.cis.bnj.ver3.core.Domain;
00010 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00011 import edu.tum.cs.bayesnets.core.Discretized;
00012
00020 public class WeightedSample {
00021 BeliefNetworkEx bn;
00026 public int[] nodeDomainIndices;
00031 public int[] nodeIndices;
00035 public double weight;
00039 public int trials;
00043 public int operations;
00044
00061 public WeightedSample(BeliefNetworkEx bn, int[] nodeDomainIndices, double weight, int[] nodeIndices, int trials) {
00062 this.bn = bn;
00063 if (nodeIndices == null) {
00064 int numNodes = nodeDomainIndices.length;
00065 nodeIndices = new int[numNodes];
00066 for (int i = 0; i < numNodes; i++) {
00067 nodeIndices[i] = i;
00068 }
00069 }
00070 this.nodeIndices = nodeIndices;
00071 this.nodeDomainIndices = nodeDomainIndices;
00072 assert nodeIndices.length == nodeDomainIndices.length;
00073 this.weight = weight;
00074 this.trials = trials;
00075 }
00076
00081 public WeightedSample(BeliefNetworkEx bn) {
00082 this(bn, new int[bn.bn.getNodes().length], 1.0, null, 0);
00083 }
00084
00085 public WeightedSample(BeliefNetworkEx bn, int[] domainIndices) {
00086 this(bn, domainIndices, 1.0, null, 0);
00087 }
00088
00098 public WeightedSample subSample(int[] queryNodes) {
00099
00100 int[] resultIndices = new int[queryNodes.length];
00101 for (int i = 0; i < queryNodes.length; i++) {
00102 resultIndices[i] = nodeDomainIndices[queryNodes[i]];
00103 }
00104 return new WeightedSample(bn, resultIndices, weight, queryNodes, 1);
00105 }
00106
00107
00108
00109
00110
00111
00112 @Override
00113 public int hashCode() {
00114 return Arrays.hashCode(nodeDomainIndices);
00115 }
00116
00117
00118
00119
00120
00121
00122 @Override
00123 public boolean equals(Object obj) {
00124 if (this == obj)
00125 return true;
00126 if (!(obj instanceof WeightedSample))
00127 return false;
00128 return Arrays.equals(nodeDomainIndices,
00129 (((WeightedSample) obj).nodeDomainIndices));
00130 }
00131
00138 public Map<String, String> getAssignmentMap() {
00139 Map<String, String> result = new HashMap<String, String>();
00140
00141 BeliefNode[] nodes = bn.bn.getNodes();
00142 for (int i = 0; i < nodeIndices.length; i++) {
00143 try {
00144 result.put(nodes[nodeIndices[i]].getName(),
00145 nodes[nodeIndices[i]].getDomain().getName(
00146 nodeDomainIndices[i]));
00147 } catch (RuntimeException e) {
00148 e.printStackTrace();
00149 throw e;
00150 }
00151 }
00152
00153 return result;
00154 }
00155
00162 public Map<String, String> getUndiscretizedAssignmentMap() {
00163 Map<String, String> result = new HashMap<String, String>();
00164
00165 BeliefNode[] nodes = bn.bn.getNodes();
00166 for (int i = 0; i < nodeIndices.length; i++) {
00167 try {
00168 Domain nodeDomain = nodes[nodeIndices[i]].getDomain();
00169 String value = nodeDomain.getName(nodeDomainIndices[i]);
00170 if (nodeDomain instanceof Discretized) {
00171 value = String.valueOf(((Discretized) nodeDomain)
00172 .getExampleValue(nodeDomainIndices[i]));
00173 }
00174 result.put(nodes[nodeIndices[i]].getName(), value);
00175 } catch (RuntimeException e) {
00176 e.printStackTrace();
00177 throw e;
00178 }
00179 }
00180 return result;
00181 }
00182
00183
00184
00185
00186
00187
00188 @Override
00189 public String toString() {
00190 return "WeightedSample(" + getAssignmentMap() + ", " + weight + ")";
00191 }
00192
00199 public String toShortString() {
00200 return "WeightedSample(" + Arrays.toString(nodeDomainIndices)
00201 + ", " + weight + ")";
00202 }
00203
00212 public boolean checkAssignment(String[][] queries) {
00213 int[] indices = bn.getNodeDomainIndicesFromStrings(queries);
00214 for (int nodeIndex : nodeIndices) {
00215 if (indices[nodeIndex] >= 0
00216 && indices[nodeIndex] != nodeDomainIndices[nodeIndex])
00217 return false;
00218 }
00219 return true;
00220 }
00221
00222 public String getCPDLookupString(BeliefNode node) {
00223 BeliefNode[] domain_product = node.getCPF().getDomainProduct();
00224 StringBuffer cond = new StringBuffer();
00225 for(int i = 0; i < domain_product.length; i++) {
00226 if(i > 0)
00227 cond.append(", ");
00228 cond.append(domain_product[i].getName()).append(" = ");
00229 cond.append(domain_product[i].getDomain().getName(nodeDomainIndices[this.bn.getNodeIndex(domain_product[i])]));
00230 }
00231 return cond.toString();
00232 }
00233 }