00001 package edu.tum.cs.bayesnets.learning;
00002
00003 import edu.ksu.cis.bnj.ver3.core.*;
00004 import edu.tum.cs.bayesnets.core.*;
00005 import edu.tum.cs.clustering.ClusterNamer;
00006
00007 import java.sql.*;
00008 import java.util.*;
00009
00010 import weka.core.*;
00011 import weka.clusterers.*;
00012
00013
00025 public class DomainLearner extends Learner {
00026
00031 ClusteredDomain[] clusteredDomains;
00032
00039 protected Instances[] clusterData;
00040
00041 protected Attribute attrValue;
00042
00049 SimpleKMeans[] clusterers;
00050
00054 protected ClusterNamer<SimpleKMeans> clusterNamer;
00055
00061 protected BeliefNode[] directDomains;
00062
00067 protected HashSet<?>[] directDomainData;
00068
00079 String[][] duplicateDomains;
00080
00081 protected boolean verbose = false;
00082
00088 static public class ClusteredDomain {
00089 public String nodeName;
00090
00091 public int numClusters;
00092
00100 public ClusteredDomain(String nodeName, int numClusters) {
00101 this.nodeName = nodeName;
00102 this.numClusters = numClusters;
00103 }
00104 }
00105
00135 public DomainLearner(BeliefNetworkEx bn, String[] directDomains,
00136 ClusteredDomain[] clusteredDomains, ClusterNamer namer,
00137 String[][] duplicateDomains) throws Exception {
00138 super(bn);
00139 init(getBeliefNodes(directDomains), clusteredDomains, namer, duplicateDomains);
00140 }
00141
00171 public DomainLearner(BeliefNetwork bn, String[] directDomains,
00172 ClusteredDomain[] clusteredDomains, ClusterNamer namer,
00173 String[][] duplicateDomains) {
00174 super(bn);
00175 init(getBeliefNodes(directDomains), clusteredDomains, namer, duplicateDomains);
00176 }
00177
00186 public DomainLearner(BeliefNetwork bn) {
00187 this(new BeliefNetworkEx(bn));
00188 }
00189
00198 public DomainLearner(BeliefNetworkEx bn) {
00199 super(bn);
00200 init(bn.bn.getNodes(), null, null, null);
00201 }
00202
00203 protected BeliefNode[] getBeliefNodes(String[] names) {
00204 BeliefNode[] nodes = new BeliefNode[names.length];
00205 for(int i = 0; i < names.length; i++)
00206 nodes[i] = this.bn.getNode(names[i]);
00207 return nodes;
00208 }
00209
00210 private void init(BeliefNode[] directDomains,
00211 ClusteredDomain[] clusteredDomains, ClusterNamer<SimpleKMeans> namer,
00212 String[][] duplicateDomains) {
00213 this.clusteredDomains = clusteredDomains;
00214 attrValue = new Attribute("value");
00215 if (clusteredDomains != null)
00216 clusterers = new SimpleKMeans[clusteredDomains.length];
00217 this.clusterNamer = namer;
00218 this.directDomains = directDomains;
00219 this.duplicateDomains = duplicateDomains;
00220
00221
00222 if (directDomains != null) {
00223 directDomainData = new HashSet<?>[directDomains.length];
00224 for (int i = 0; i < directDomains.length; i++)
00225 directDomainData[i] = new HashSet<String>();
00226 }
00227
00228
00229 if (clusteredDomains != null) {
00230 clusterData = new Instances[clusteredDomains.length];
00231 for (int i = 0; i < clusteredDomains.length; i++) {
00232 FastVector attribs = new FastVector(1);
00233 attribs.addElement(attrValue);
00234 clusterData[i] = new Instances(clusteredDomains[i].nodeName,
00235 attribs, 100);
00236 }
00237 }
00238 }
00239
00256 public void learn(ResultSet rs) throws Exception, SQLException {
00257
00258 if (!rs.next())
00259 throw new Exception("empty result set!");
00260
00261
00262 int numDirectDomains = directDomains != null ? directDomains.length : 0;
00263 int numClusteredDomains = clusteredDomains != null ? clusteredDomains.length
00264 : 0;
00265 do {
00266
00267 for (int i = 0; i < numDirectDomains; i++) {
00268 ((HashSet<String>) directDomainData[i]).add(rs
00269 .getString(directDomains[i].getName()));
00270 }
00271
00272 for (int i = 0; i < numClusteredDomains; i++) {
00273 Instance inst = new Instance(1);
00274 inst.setValue(attrValue, rs
00275 .getDouble(clusteredDomains[i].nodeName));
00276 clusterData[i].add(inst);
00277 }
00278 } while (rs.next());
00279 }
00280
00297 public void learn(Instances instances) throws Exception, SQLException {
00298
00299 if(instances.numInstances() == 0)
00300 throw new Exception("empty result set!");
00301
00302
00303 int numDirectDomains = directDomains != null ? directDomains.length : 0;
00304 int numClusteredDomains = clusteredDomains != null ? clusteredDomains.length : 0;
00305 Enumeration<Instance> instanceEnum = instances.enumerateInstances();
00306 while (instanceEnum.hasMoreElements()) {
00307 Instance instance = instanceEnum.nextElement();
00308
00309 for (int i = 0; i < numDirectDomains; i++) {
00310 ((HashSet<String>) directDomainData[i]).add(instance.stringValue(
00311 instances.attribute(directDomains[i].getName())));
00312 }
00313
00314 for (int i = 0; i < numClusteredDomains; i++) {
00315 Instance inst = new Instance(1);
00316 inst.setValue(attrValue, instance.value(
00317 instances.attribute(clusteredDomains[i].nodeName)));
00318 clusterData[i].add(inst);
00319 }
00320 }
00321 }
00322
00333 public void learn(Map<String, String> data) throws Exception {
00334 int numDirectDomains = directDomains != null ? directDomains.length : 0;
00335 int numClusteredDomains = clusteredDomains != null ? clusteredDomains.length
00336 : 0;
00337
00338
00339 for (int i = 0; i < numDirectDomains; i++) {
00340 String val = data.get(directDomains[i]);
00341 if (val == null)
00342 throw new Exception("Key " + clusteredDomains[i].nodeName
00343 + " not found in data!");
00344 ((HashSet<String>) directDomainData[i]).add(val);
00345 }
00346
00347 for (int i = 0; i < numClusteredDomains; i++) {
00348 Instance inst = new Instance(1);
00349 String val = data.get(clusteredDomains[i].nodeName);
00350 if (val == null) {
00351 boolean b = data.containsKey(clusteredDomains[i].nodeName);
00352 throw new Exception("Key " + clusteredDomains[i].nodeName
00353 + " not found in data!");
00354 }
00355 inst.setValue(attrValue, Double.parseDouble(val));
00356 clusterData[i].add(inst);
00357 }
00358 }
00359
00368
00369
00370
00371
00372
00373
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388
00389
00390
00391
00392
00393
00394
00395
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409
00410
00411
00412
00413
00419 protected void end_learning() throws Exception {
00420 if (directDomains != null)
00421 for (int i = 0; i < directDomains.length; i++) {
00422 if (verbose)
00423 System.out.println(directDomains[i]);
00424 HashSet<String> hs = (HashSet<String>) directDomainData[i];
00425 Discrete domain = new Discrete();
00426 for (Iterator<String> iter = hs.iterator(); iter.hasNext();)
00427 domain.addName(iter.next());
00428 BeliefNode node = directDomains[i];
00429 if (node == null) {
00430 System.out.println("No node with name '" + directDomains[i]
00431 + "' found to learn direct domain for.");
00432 }
00433
00434 bn.bn.changeBeliefNodeDomain(node, domain);
00435 }
00436 if (clusteredDomains != null)
00437 for (int i = 0; i < clusteredDomains.length; i++) {
00438 if (verbose)
00439 System.out.println(clusteredDomains[i].nodeName);
00440 try {
00441
00442 clusterers[i] = new SimpleKMeans();
00443 if (clusteredDomains[i].numClusters != 0)
00444 clusterers[i]
00445 .setNumClusters(clusteredDomains[i].numClusters);
00446 clusterers[i].buildClusterer(clusterData[i]);
00447
00448 bn.bn.changeBeliefNodeDomain(bn
00449 .getNode(clusteredDomains[i].nodeName),
00450 new Discretized(clusterers[i], clusterNamer));
00451 } catch (Exception e) {
00452 e.printStackTrace();
00453 }
00454 }
00455 if (duplicateDomains != null) {
00456 for (int i = 0; i < duplicateDomains.length; i++) {
00457 Domain srcDomain = bn.getDomain(duplicateDomains[i][0]);
00458 for (int j = 1; j < duplicateDomains[i].length; j++) {
00459 if (verbose)
00460 System.out.println(duplicateDomains[i][j]);
00461 bn.bn.changeBeliefNodeDomain(bn
00462 .getNode(duplicateDomains[i][j]), srcDomain);
00463 }
00464 }
00465 }
00466 }
00467
00476 public SimpleKMeans[] getClusterers() throws Exception {
00477 finish();
00478 return clusterers;
00479 }
00480
00489 public void sortClusteredDomains() {
00490
00491 for (int i = 0; i < clusteredDomains.length; i++) {
00492 BeliefNode node = bn.getNode(clusteredDomains[i].nodeName);
00493 sortClusteredDomain(node, clusterers[i]);
00494 }
00495
00496 if (duplicateDomains != null) {
00497 for (int i = 0; i < duplicateDomains.length; i++)
00498 for (int j = 0; j < clusteredDomains.length; j++)
00499 if (duplicateDomains[i][0]
00500 .equals(clusteredDomains[j].nodeName)) {
00501 for (int k = 1; k < duplicateDomains[i].length; k++)
00502 sortClusteredDomain(bn
00503 .getNode(duplicateDomains[i][k]),
00504 clusterers[j]);
00505 break;
00506 }
00507 }
00508 }
00509
00517 protected void sortClusteredDomain(BeliefNode node, SimpleKMeans clusterer) {
00518
00519
00520
00521 int numClusters = clusterer.getNumClusters();
00522 double[] values = clusterer.getClusterCentroids()
00523 .attributeToDoubleArray(0);
00524 double[] sorted_values = (double[]) values.clone();
00525 Arrays.sort(sorted_values);
00526
00527 Discrete domain = (Discrete) node.getDomain();
00528 Discrete sorted_domain = new Discrete();
00529 for (int new_idx = 0; new_idx < numClusters; new_idx++) {
00530 for (int old_idx = 0; old_idx < numClusters; old_idx++)
00531 if (values[old_idx] == sorted_values[new_idx])
00532 sorted_domain.addName(domain.getName(old_idx));
00533 }
00534
00535 bn.bn.changeBeliefNodeDomain(node, sorted_domain);
00536 }
00537 }