00001
00002
00003 package edu.tum.cs.bayesnets.core.io;
00004
00005 import edu.ksu.cis.bnj.ver3.core.BeliefNetwork;
00006 import edu.ksu.cis.bnj.ver3.streams.*;
00007 import java.io.*;
00008 import java.util.*;
00009 import java.util.Map.Entry;
00010
00011 import javax.xml.parsers.DocumentBuilder;
00012 import javax.xml.parsers.DocumentBuilderFactory;
00013 import org.w3c.dom.*;
00014
00023 public class Converter_pmml
00024 implements OmniFormatV1, Exporter, Importer
00025 {
00026 protected OmniFormatV1 _Writer;
00027 protected int bn_cnt;
00028 private int bnode_cnt;
00029
00030
00031 protected HashMap<Integer, NodeData> nodeData;
00032 protected Writer w;
00033 public int netDepth;
00034 protected int curNodeIdx;
00035
00036
00037
00038 protected HashMap<Integer, String> nodeNames;
00039
00040 protected HashMap<Integer, Integer> nodeIndices;
00041 protected NodeData curNode;
00042 HashMap<Integer, Node> cptTags;
00043
00044
00045 protected StringBuffer cpf;
00046 protected int cpfNodeID;
00047
00048
00049 public Converter_pmml()
00050 {
00051 w = null;
00052 curNodeIdx = 0;
00053 }
00054
00055 public OmniFormatV1 getStream1()
00056 {
00057 return this;
00058 }
00059
00060
00061
00062
00063
00064 public void load(InputStream stream, OmniFormatV1 writer)
00065 {
00066 _Writer = writer;
00067 _Writer.Start();
00068 bn_cnt = 0;
00069 bnode_cnt = 0;
00070 nodeIndices = new HashMap<Integer, Integer>();
00071 DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
00072
00073 factory.setNamespaceAware(true);
00074 org.w3c.dom.Document doc;
00075 try
00076 {
00077 DocumentBuilder parser = factory.newDocumentBuilder();
00078 doc = parser.parse(stream);
00079 }
00080 catch(Exception e)
00081 {
00082 throw new RuntimeException(e);
00083 }
00084 visitDocument(doc);
00085 System.gc();
00086 }
00087
00088 public void visitDocument(Node parent)
00089 {
00090 NodeList l = parent.getChildNodes();
00091 if(l == null)
00092 throw new RuntimeException("Unexpected end of document!");
00093 int max = l.getLength();
00094 for(int i = 0; i < max; i++) {
00095 Node node = l.item(i);
00096 switch(node.getNodeType())
00097 {
00098 case 1:
00099 String name = node.getNodeName();
00100 if(name.equals("PMML"))
00101 {
00102
00103 NamedNodeMap attrs = node.getAttributes();
00104 if(attrs != null) {
00105 int amax = attrs.getLength();
00106 for(int j = 0; j < amax; j++) {
00107 Node attr = attrs.item(j);
00108 String aname = attr.getNodeName().toUpperCase();
00109 if(aname.equals("version"))
00110 try {
00111 if(!aname.equals("3.0"))
00112 throw new RuntimeException("PMML version " + aname + " is not supported");
00113 }
00114 catch(Exception e) { }
00115
00116
00117 }
00118 }
00119
00120
00121 cptTags = new HashMap<Integer, Node>();
00122 visitDocument(node);
00123
00124
00125
00126 for(Entry<Integer,Node> e : cptTags.entrySet()) {
00127 visitDefinition(e.getValue(), e.getKey());
00128 }
00129 cptTags = null;
00130 }
00131 else if(name.equals("DataDictionary")) {
00132 _Writer.CreateBeliefNetwork(bn_cnt);
00133 visitDataDict(node);
00134 bn_cnt++;
00135 }
00136
00137
00138
00139 }
00140 }
00141
00142 }
00143
00144 public void visitDataDict(Node parent)
00145 {
00146
00147 NodeList l = parent.getChildNodes();
00148 if(l == null)
00149 throw new RuntimeException("Unexpected end of document!");
00150 int max = l.getLength();
00151 for(int i = 0; i < max; i++) {
00152 Node node = l.item(i);
00153 switch(node.getNodeType()) {
00154 case 1:
00155 String name = node.getNodeName();
00156 if(name.equals("DataField")) {
00157 visitDataField(node);
00158 bnode_cnt++;
00159 }
00160 }
00161 }
00162 }
00163
00168 protected void visitDataField(Node parent)
00169 {
00170
00171 NamedNodeMap attrs = parent.getAttributes();
00172 String nodeName = null;
00173 Integer nodeID = null;
00174 int max;
00175 if(attrs != null) {
00176 max = attrs.getLength();
00177 for(int i = 0; i < max; i++) {
00178 Node attr = attrs.item(i);
00179 String attrName = attr.getNodeName();
00180 String value = attr.getNodeValue();
00181 if(attrName.equals("name")) {
00182 nodeName = value;
00183
00184 }
00185 if(attrName.equals("id")) {
00186 nodeID = Integer.parseInt(value);
00187 }
00188
00189
00190
00191 }
00192 }
00193
00194 if(nodeName == null || nodeID == null)
00195 throw new RuntimeException("Missing DataField attribute 'name' or 'id'!");
00196
00197 nodeIndices.put(nodeID, new Integer(bnode_cnt));
00198
00199 _Writer.BeginBeliefNode(bnode_cnt);
00200 _Writer.SetBeliefNodeName(nodeName);
00201
00202
00203
00204 NodeList l = parent.getChildNodes();
00205 max = l.getLength();
00206 for(int i = 0; i < max; i++) {
00207 Node node = l.item(i);
00208 switch(node.getNodeType()) {
00209 case 1:
00210 String name = node.getNodeName();
00211 if(name.equals("Value")) {
00212 attrs = node.getAttributes();
00213 for(int j = attrs.getLength()-1; j >= 0; j--) {
00214 Node attr = attrs.item(j);
00215 if(attr.getNodeName().equals("value"))
00216 _Writer.BeliefNodeOutcome(attr.getNodeValue());
00217 }
00218 }
00219 else if(name.equals("Extension")) {
00220 NodeList l_ext = node.getChildNodes();
00221 for(int j = 0; j < l_ext.getLength(); j++) {
00222 Node n = l_ext.item(j);
00223 if(n.getNodeName().equals("X-NodeType")) {
00224 _Writer.SetType(getElementValue(n));
00225 }
00226 else if(n.getNodeName().equals("X-Position")) {
00227 attrs = n.getAttributes();
00228 int xPos = 0, yPos = 0, have = 0;
00229 for(int k = attrs.getLength()-1; k >= 0; k--) {
00230 Node attr = attrs.item(k);
00231 if(attr.getNodeName().equals("x"))
00232 xPos = Integer.parseInt(attr.getNodeValue());
00233 else if(attr.getNodeName().equals("y"))
00234 yPos = Integer.parseInt(attr.getNodeValue());
00235 }
00236 _Writer.SetBeliefNodePosition(xPos, yPos);
00237 }
00238 else if(n.getNodeName().equals("X-Definition"))
00239 cptTags.put(nodeID, n);
00240 }
00241 }
00242 break;
00243 }
00244 }
00245
00246 _Writer.EndBeliefNode();
00247 }
00248
00249 protected void visitDefinition(Node definition, int nodeID)
00250 {
00251 NodeList l = definition.getChildNodes();
00252 if(l == null)
00253 return;
00254 LinkedList<Integer> parents = new LinkedList<Integer>();
00255 int curNode = nodeIndices.get(nodeID);
00256 String CPTString = "";
00257 int max = l.getLength();
00258 for(int i = 0; i < max; i++)
00259 {
00260 Node node = l.item(i);
00261 switch(node.getNodeType())
00262 {
00263 case 1:
00264 String name = node.getNodeName();
00265 if(name.equals("X-Given")) {
00266 parents.add(nodeIndices.get(Integer.parseInt(getElementValue(node))));
00267 }
00268 else
00269 if(name.equals("X-Table"))
00270 CPTString = getElementValue(node);
00271 }
00272 }
00273
00274 if(curNode >= 0)
00275 {
00276 for(Integer p : parents) {
00277 _Writer.Connect(p, curNode);
00278 }
00279
00280 _Writer.BeginCPF(curNode);
00281 StringTokenizer tok = new StringTokenizer(CPTString);
00282 int maxz = tok.countTokens();
00283 for(int c = 0; c < maxz; c++)
00284 {
00285 String SSS = tok.nextToken();
00286 _Writer.ForwardFlat_CPFWriteValue(SSS);
00287 }
00288
00289 _Writer.EndCPF();
00290 }
00291 }
00292
00293 protected String getElementValue(Node parent)
00294 {
00295 NodeList l = parent.getChildNodes();
00296 if(l == null)
00297 return null;
00298 StringBuffer buf = new StringBuffer();
00299 int max = l.getLength();
00300 for(int i = 0; i < max; i++)
00301 {
00302 Node node = l.item(i);
00303 switch(node.getNodeType())
00304 {
00305 case 3:
00306 buf.append(node.getNodeValue());
00307 break;
00308
00309 default:
00310 System.out.println("Unhandled node " + node.getNodeName());
00311 break;
00312
00313 case 1:
00314 case 8:
00315 break;
00316 }
00317 }
00318
00319 return buf.toString().trim();
00320 }
00321
00322
00323
00324
00325
00326 protected class NodeData {
00327 public String cpfData, subElements, nodeType, opType, name, domainClassName;
00328 int index;
00329 int xPos, yPos;
00330 Vector<Integer> parents;
00331 public NodeData() {
00332 cpfData = new String();
00333 subElements = new String();
00334 parents = new Vector<Integer>();
00335 }
00336 }
00337
00338 public void save(BeliefNetwork bn, OutputStream os) {
00339 w = new OutputStreamWriter(os);
00340 OmniFormatV1_Writer.Write(bn, this);
00341 }
00342
00343 public void fwrite(String x)
00344 {
00345 try
00346 {
00347 w.write(x);
00348 w.flush();
00349 }
00350 catch(Exception e)
00351 {
00352 System.out.println("unable to write?");
00353 }
00354 }
00355
00356 public void Start()
00357 {
00358 netDepth = 0;
00359 nodeNames = new HashMap<Integer, String>();
00360
00361 fwrite("<?xml version=\"1.0\" encoding=\"US-ASCII\"?>\n");
00362 fwrite("<!-- Bayesian network in a PMML-based format -->\n");
00363 fwrite("<PMML version=\"3.0\" xmlns=\"http://www.dmg.org/PMML-3_0\">\n");
00364 fwrite("\t<Header copyright=\"Technische Universitaet Muenchen\" />\n");
00365 }
00366
00367 public void CreateBeliefNetwork(int idx)
00368 {
00369 if(netDepth > 0)
00370 {
00371 netDepth = 0;
00372 fwrite("\t</DataDictionary>\n");
00373 }
00374 nodeData = new HashMap<Integer,NodeData>();
00375 fwrite("\t<DataDictionary>\n");
00376 netDepth = 1;
00377 }
00378
00379 public void SetBeliefNetworkName(int idx, String name)
00380 {
00381
00382 }
00383
00384 public void BeginBeliefNode(int idx) {
00385 curNode = new NodeData();
00386 curNode.index = idx;
00387 curNodeIdx = idx;
00388
00389 }
00390
00391 public void SetType(String type)
00392 {
00393 curNode.nodeType = type;
00394 if(type.equals("utility"))
00395 curNode.opType = "continuous";
00396 else
00397 curNode.opType = "categorical";
00398 }
00399
00400 public void SetBeliefNodePosition(int x, int y) {
00401 curNode.xPos = x;
00402 curNode.yPos = y;
00403 }
00404
00405 public void SetBeliefNodeDomainClass(String domainClassName) {
00406 curNode.domainClassName = domainClassName;
00407 }
00408
00409 public void BeliefNodeOutcome(String outcome) {
00410 curNode.subElements += "\t\t\t<Value value=\"" + outcome.replaceAll("<", "<").replaceAll(">", ">") + "\" />\n";
00411 }
00412
00413 public void SetBeliefNodeName(String name) {
00414
00415 curNode.name = name;
00416 nodeNames.put(new Integer(curNodeIdx), name);
00417 }
00418
00419 public void MakeContinuous(String s) {
00420 }
00421
00422 public void EndBeliefNode() {
00423 nodeData.put(curNode.index, curNode);
00424 }
00425
00426 public void Connect(int par_idx, int chi_idx) {
00427 nodeData.get(chi_idx).parents.add(par_idx);
00428 }
00429
00430 public void BeginCPF(int idx) {
00431
00432 cpfNodeID = idx;
00433 cpf = new StringBuffer("\t\t\t\t<X-Definition>\n");
00434 String gname;
00435 for(Integer given : nodeData.get(idx).parents)
00436 {
00437
00438 gname = (String)nodeNames.get(given);
00439 cpf.append("\t\t\t\t\t<X-Given>" + given + "</X-Given> <!-- " + gname + " -->\n");
00440 }
00441 cpf.append("\t\t\t\t\t<X-Table>");
00442 }
00443
00444 public void ForwardFlat_CPFWriteValue(String x)
00445 {
00446 cpf.append(x + " ");
00447 }
00448
00449 public void EndCPF()
00450 {
00451 cpf.append("</X-Table>\n");
00452 cpf.append("\t\t\t\t</X-Definition>\n");
00453 NodeData d = (NodeData)nodeData.get(cpfNodeID);
00454 d.cpfData = cpf.toString();
00455
00456 }
00457
00458 public int GetCPFSize()
00459 {
00460 return 0;
00461 }
00462
00463 public void Finish()
00464 {
00465 if(netDepth > 0)
00466 {
00467
00468 Iterator<NodeData> i = nodeData.values().iterator();
00469 while(i.hasNext()) {
00470 NodeData nd = i.next();
00471 fwrite("\t\t<DataField name=\"" + nd.name + "\" optype=\"" + nd.opType + "\" id=\"" + nd.index + "\">\n");
00472 fwrite("\t\t\t<Extension>\n");
00473 fwrite("\t\t\t\t<X-NodeType>" + nd.nodeType + "</X-NodeType>\n");
00474 if (nd.domainClassName != null)
00475 fwrite("\t\t\t\t<X-NodeDomainClass>" + nd.domainClassName + "</X-NodeDomainClass>\n");
00476 fwrite("\t\t\t\t<X-Position x=\"" + nd.xPos + "\" y=\"" + nd.yPos + "\" />\n");
00477 fwrite(nd.cpfData);
00478 fwrite("\t\t\t</Extension>\n");
00479 fwrite(nd.subElements);
00480 fwrite("\t\t</DataField>\n");
00481 }
00482
00483 netDepth = 0;
00484 fwrite("\t</DataDictionary>\n");
00485 }
00486 fwrite("</PMML>\n");
00487 try
00488 {
00489 w.close();
00490 }
00491 catch(Exception exception) { }
00492 }
00493
00494
00495
00496 public String getExt() {
00497 return "*.pmml";
00498 }
00499
00500 public String getDesc() {
00501 return "PMML 3.0";
00502 }
00503 }