00001 package edu.tum.cs.bayesnets.learning;
00002
00003
00004 import edu.ksu.cis.bnj.ver3.core.*;
00005 import edu.ksu.cis.bnj.ver3.core.values.Field;
00006 import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
00007 import java.sql.*;
00008 import java.util.*;
00009
00010 import org.apache.log4j.Level;
00011 import org.apache.log4j.Logger;
00012
00013 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00014 import edu.tum.cs.bayesnets.core.Discretized;
00015
00016 import weka.clusterers.*;
00017 import weka.core.*;
00018
00026 public class CPTLearner extends Learner {
/** Logger shared by all instances of this class. */
static final Logger logger = Logger.getLogger(CPTLearner.class);

static {
    // Only warnings and above are emitted; the debug diagnostics in this class stay silent.
    logger.setLevel(Level.WARN);
}

/**
 * One example counter per network node; entry {@code i} is created for {@code nodes[i]}.
 */
protected ExampleCounter[] counters;

/**
 * Optional clusterer per network node, indexed like {@code nodes}. Where an entry is
 * {@code null}, the learn methods look the value up by name in the node's discrete domain;
 * otherwise the clusterer maps the numeric value to a domain index.
 */
protected Clusterer[] clusterers;

/**
 * Flag passed to {@code normalizeByDomain} for every node's CPF when learning ends.
 */
protected boolean uniformDefault;
00049
/**
 * Creates a CPT learner for the given network and prepares its per-node
 * counters and (initially empty) clusterer table.
 *
 * @param bn the belief network whose CPTs are to be learnt
 */
public CPTLearner(BeliefNetworkEx bn) {
    super(bn);
    this.init();
}
00058
00063 public void setUniformDefault(boolean value) {
00064 uniformDefault = value;
00065 }
00066
/**
 * Creates a CPT learner from a domain learner, reusing the clusterers it holds.
 * <p>
 * Each clustered domain's clusterer is registered for that domain's node. For every
 * group of duplicate domains whose first entry names a clustered domain, the same
 * clusterer is additionally registered for the remaining names of the group.
 *
 * @param dl the domain learner providing the network, clustered domains and clusterers
 * @throws Exception if one of the referenced node names is unknown to the network
 */
public CPTLearner(DomainLearner dl) throws Exception {
    super(dl.bn.bn);
    init();

    // Without clustered domains there is nothing to take over.
    if (dl.clusteredDomains == null) {
        return;
    }

    for (int c = 0; c < dl.clusteredDomains.length; c++) {
        addClusterer(dl.clusteredDomains[c].nodeName, dl.clusterers[c]);
    }

    if (dl.duplicateDomains == null) {
        return;
    }

    for (int d = 0; d < dl.duplicateDomains.length; d++) {
        // Locate the first clustered domain that carries the name of the group's first entry.
        int match = 0;
        while (match < dl.clusteredDomains.length
                && !dl.duplicateDomains[d][0].equals(dl.clusteredDomains[match].nodeName)) {
            match++;
        }
        if (match < dl.clusteredDomains.length) {
            // All further entries of the group share that domain's clusterer.
            for (int k = 1; k < dl.duplicateDomains[d].length; k++) {
                addClusterer(dl.duplicateDomains[d][k], dl.clusterers[match]);
            }
        }
    }
}
00094
00099 private void init() {
00100 uniformDefault = false;
00101 clusterers = new Clusterer[nodes.length];
00102
00103 counters = new ExampleCounter[nodes.length];
00104 for(int i = 0; i < nodes.length; i++)
00105 counters[i] = new ExampleCounter(nodes[i], bn);
00106 }
00107
00117 public void learn(ResultSet rs) throws Exception {
00118 try {
00119
00120 if(!rs.next())
00121 throw new Exception("empty result set!");
00122
00123 BeliefNode[] nodes = bn.bn.getNodes();
00124 ResultSetMetaData rsmd = rs.getMetaData();
00125 int numCols = rsmd.getColumnCount();
00126
00127
00128
00129
00130
00131 int[] nodeIdx2colIdx = new int[nodes.length];
00132 Arrays.fill(nodeIdx2colIdx, -1);
00133 for(int i = 1; i <= numCols; i++) {
00134 Set<String> nodeNames = bn.getNodeNamesForAttribute(rsmd.getColumnName(i));
00135 for (String nodeName: nodeNames) {
00136 int node_idx = bn.getNodeIndex(nodeName);
00137 if(node_idx == -1)
00138 throw new Exception("Unknown node referenced in result set: " + rsmd.getColumnName(i));
00139 nodeIdx2colIdx[node_idx] = i;
00140 }
00141 }
00142
00143
00144 int[] domainIndices = new int[nodes.length];
00145 do {
00146
00147
00148
00149
00150
00151 for(int node_idx = 0; node_idx < nodes.length; node_idx++) {
00152 int domain_idx;
00153 if(clusterers[node_idx] == null) {
00154 Discrete domain = (Discrete) nodes[node_idx].getDomain();
00155
00156 String strValue;
00157 if (domain instanceof Discretized) {
00158 double value = rs.getDouble(nodeIdx2colIdx[node_idx]);
00159 strValue = (((Discretized)domain).getNameFromContinuous(value));
00160 } else {
00161 strValue = rs.getString(nodeIdx2colIdx[node_idx]);
00162 }
00163 domain_idx = domain.findName(strValue);
00164 if(domain_idx == -1)
00165 throw new Exception(strValue + " not found in domain of " + nodes[node_idx].getName());
00166 }
00167 else {
00168 Instance inst = new Instance(1);
00169 double value = rs.getDouble(bn.getAttributeNameForNode(bn.bn.getNodes()[node_idx].getName()));
00170 inst.setValue(0, value);
00171 domain_idx = clusterers[node_idx].clusterInstance(inst);
00172 }
00173 domainIndices[node_idx] = domain_idx;
00174 }
00175
00176 for(int i = 0; i < nodes.length; i++) {
00177 counters[i].count(domainIndices);
00178 }
00179 } while(rs.next());
00180 }
00181 catch (SQLException ex) {
00182 System.out.println("SQLException: " + ex.getMessage());
00183 System.out.println("SQLState: " + ex.getSQLState());
00184 System.out.println("VendorError: " + ex.getErrorCode());
00185 }
00186 }
00187
00197 public void learn(Instances instances) throws Exception {
00198
00199 if(instances.numInstances() == 0)
00200 throw new Exception("empty result set!");
00201
00202 BeliefNode[] nodes = bn.bn.getNodes();
00203 int numAttributes = instances.numAttributes();
00204
00205
00206
00207
00208
00209 int[] nodeIdx2colIdx = new int[nodes.length];
00210 Arrays.fill(nodeIdx2colIdx, -1);
00211 for(int i = 0; i < numAttributes; i++) {
00212 Set<String> nodeNames = bn.getNodeNamesForAttribute(instances.attribute(i).name());
00213 logger.debug("Nodes for attribute "+instances.attribute(i).name()+": "+nodeNames);
00214 if (nodeNames==null)
00215 continue;
00216 for (String nodeName: nodeNames) {
00217 int node_idx = bn.getNodeIndex(nodeName);
00218 if(node_idx == -1)
00219 throw new Exception("Unknown node referenced in result set: " + instances.attribute(i).name());
00220 nodeIdx2colIdx[node_idx] = i;
00221 }
00222 }
00223
00224
00225 int[] domainIndices = new int[nodes.length];
00226 Enumeration<Instance> instanceEnum = instances.enumerateInstances();
00227 while (instanceEnum.hasMoreElements()) {
00228 Instance instance = instanceEnum.nextElement();
00229
00230
00231
00232
00233
00234 for(int node_idx = 0; node_idx < nodes.length; node_idx++) {
00235 int domain_idx;
00236 if(clusterers[node_idx] == null) {
00237 Discrete domain = (Discrete) nodes[node_idx].getDomain();
00238 String strValue;
00239 if (domain instanceof Discretized) {
00240 int colIdx = nodeIdx2colIdx[node_idx];
00241 if (colIdx < 0) {
00242 bn.dump();
00243 for (int i = 0; i < numAttributes; i++) {
00244 logger.debug("Attribute "+i+": "+instances.attribute(i).name());
00245 }
00246 StringBuffer sb = new StringBuffer();
00247 for (int i = 0; i < nodeIdx2colIdx.length; i++) {
00248 sb.append(i+"\t");
00249 }
00250 sb.append("\n");
00251 for (int i = 0; i < nodeIdx2colIdx.length; i++) {
00252 sb.append(nodeIdx2colIdx[i]+"\t");
00253 }
00254 logger.debug(sb);
00255 throw new Exception("No attribute specified for "+bn.bn.getNodes()[node_idx].getName());
00256 }
00257 double value = instance.value(colIdx);
00258 strValue = (((Discretized)domain).getNameFromContinuous(value));
00259 if (domain.findName(strValue) == -1) {
00260 logger.debug(domain);
00261 logger.debug(strValue);
00262 }
00263 } else {
00264 int colIdx = nodeIdx2colIdx[node_idx];
00265 if (colIdx < 0) {
00266 throw new Exception("No attribute specified for "+bn.bn.getNodes()[node_idx].getName());
00267 }
00268 strValue = instance.stringValue(nodeIdx2colIdx[node_idx]);
00269 }
00270 domain_idx = domain.findName(strValue);
00271 if(domain_idx == -1) {
00272 String[] myDomain = bn.getDiscreteDomainAsArray(bn.bn.getNodes()[node_idx].getName());
00273 for (int i=0; i<myDomain.length; i++) {
00274 logger.debug(myDomain[i]);
00275 }
00276 throw new Exception(strValue + " not found in domain of " + nodes[node_idx].getName());
00277 }
00278 }
00279 else {
00280 Instance inst = new Instance(1);
00281 inst.setValue(0, instance.value(nodeIdx2colIdx[node_idx]));
00282 domain_idx = clusterers[node_idx].clusterInstance(inst);
00283 }
00284 domainIndices[node_idx] = domain_idx;
00285 }
00286
00287 for(int i = 0; i < nodes.length; i++) {
00288 counters[i].count(domainIndices);
00289 }
00290 }
00291 }
00292
00301 public void learn(Map<String,String> data) throws Exception {
00302
00303
00304
00305
00306 BeliefNode[] nodes = bn.bn.getNodes();
00307 int[] domainIndices = new int[nodes.length];
00308 for(int node_idx = 0; node_idx < nodes.length; node_idx++) {
00309 int domain_idx;
00310 String value = data.get(nodes[node_idx].getName());
00311 if(value == null)
00312 throw new Exception("Key " + nodes[node_idx].getName() + " not found in data!");
00313 if(clusterers[node_idx] == null) {
00314 Discrete domain = (Discrete) nodes[node_idx].getDomain();
00315 domain_idx = domain.findName(value);
00316 if(domain_idx == -1)
00317 throw new Exception(value + " not found in domain of " + nodes[node_idx].getName());
00318 }
00319 else {
00320 Instance inst = new Instance(1);
00321 inst.setValue(0, Double.parseDouble(value));
00322 domain_idx = clusterers[node_idx].clusterInstance(inst);
00323 }
00324 domainIndices[node_idx] = domain_idx;
00325 }
00326
00327 for(int i = 0; i < nodes.length; i++) {
00328 counters[i].count(domainIndices);
00329 }
00330 }
00331
00337
00338
00339
00340
00341
00342
00343
00344
00345
00346
00347
00348
00349
00350
00351
00352
00353
00354
00355
00356
00357
00358
00359
00360
00361
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371
00372
00373
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388
00389
00397 public void addClusterer(String nodeName, Clusterer clusterer) throws Exception {
00398 for(int i = 0; i < nodes.length; i++)
00399 if(nodes[i].getName().equals(nodeName)) {
00400 clusterers[i] = clusterer;
00401 return;
00402 }
00403 throw new Exception("Passed unknown node name!");
00404 }
00405
00409 protected void end_learning() {
00410
00411 for(int i = 0; i < nodes.length; i++)
00412 nodes[i].getCPF().normalizeByDomain(uniformDefault);
00413 }
00414
00415
00419 protected class ExampleCounter {
00420 CPF cpf;
00424 public int[] nodeIndices;
00425
00431 public ExampleCounter(BeliefNode n, BeliefNetworkEx bn) {
00432
00433 cpf = n.getCPF();
00434 for(int i = 0; i < cpf.size(); i++)
00435 cpf.put(i, new ValueDouble(0));
00436
00437
00438 BeliefNode[] nodes = cpf.getDomainProduct();
00439 nodeIndices = new int[nodes.length];
00440 for(int i = 0; i < nodes.length; i++)
00441 nodeIndices[i] = bn.getNodeIndex(nodes[i]);
00442 }
00443
00444 public ExampleCounter(CPF cpf, int[] nodeIndices) {
00445 this.cpf = cpf;
00446 this.nodeIndices = nodeIndices;
00447 }
00448
00457 public void count(int[] domainIndices) {
00458 count(domainIndices, 1.0);
00459 }
00460
00470 public void count(int[] domainIndices, double weight) {
00471 int[] addr = new int[nodeIndices.length];
00472
00473
00474 for(int i = 0; i < nodeIndices.length; i++) {
00475 addr[i] = domainIndices[nodeIndices[i]];
00476 }
00477
00478
00479 int realAddr = cpf.addr2realaddr(addr);
00480
00481 cpf.put(realAddr, Field.add(cpf.get(realAddr), new ValueDouble(weight)) );
00482 }
00483 }
00484 }