00001 package edu.tum.cs.probcog.prolog;
00002
00003 import java.io.IOException;
00004 import java.util.Arrays;
00005 import java.util.HashMap;
00006 import java.util.HashSet;
00007 import java.util.Hashtable;
00008 import java.util.Iterator;
00009 import java.util.Map;
00010 import java.util.Set;
00011 import java.util.Vector;
00012
00013 import jpl.Query;
00014 import jpl.fli.Prolog;
00015 import edu.tum.cs.logic.parser.ParseException;
00016 import edu.tum.cs.probcog.InferenceResult;
00017 import edu.tum.cs.probcog.Model;
00018 import edu.tum.cs.probcog.Server;
00019
00027 public class PrologInterface {
00028
00029 public static final String UNKNOWN_TYPE = "TYPE_UNKNOWN";
00030
00035 private static Map<String, String> objectTypes = new HashMap<String, String>();
00036
00037 private static String modelPool = "/tmp/buildd/ros-diamondback-knowrob-0.2.0/debian/ros-diamondback-knowrob/opt/ros/diamondback/stacks/knowrob/srldb/models/models.xml";
00038
00039 private static String modelName = "tableSetting_fall10";
00040
00041 private static Server server = null;
00042
00043
00044 private static void initProlog() {
00045 try {
00046 Vector<String> args= new Vector<String>(Arrays.asList(Prolog.get_default_init_args()));
00047 args.add( "-G256M" );
00048
00049 args.add( "-nosignals" );
00050 Prolog.set_default_init_args( args.toArray( new String[0] ) );
00051
00052
00053 new Query("ensure_loaded('/home/tenorth/work/owl/gram_ias_human.pl')").oneSolution();
00054
00055 } catch(Exception e) {
00056 e.printStackTrace();
00057 }
00058 }
00059
00063 public static void reset() {
00064 objectTypes.clear();
00065 server = null;
00066 }
00067
00076 public static Map<String, Vector<Object>> executeQuery(String query,
00077 String plFile) {
00078
00079 System.err.println("Executing query: " + query);
00080
00081 HashMap<String, Vector<Object>> result = new HashMap<String, Vector<Object>>();
00082 Hashtable<?, ?>[] solutions;
00083
00084 Query q = new Query("expand_goal((" + query + "),_9), call(_9)");
00085
00086
00087 if (!q.hasMoreSolutions())
00088 return new HashMap<String, Vector<Object>>();
00089 Hashtable<?, ?> oneSolution = q.nextSolution();
00090 if (oneSolution.isEmpty())
00091
00092 return new HashMap<String, Vector<Object>>();
00093
00094
00095
00096
00097 q.rewind();
00098 solutions = q.allSolutions();
00099
00100 for (Object key : solutions[0].keySet()) {
00101 result.put(key.toString(), new Vector<Object>());
00102 }
00103
00104
00105 for (int i = 0; i < solutions.length; i++) {
00106 Hashtable<?, ?> solution = solutions[i];
00107 for (Object key : solution.keySet()) {
00108 String keyStr = key.toString();
00109
00110 if (!result.containsKey(keyStr)) {
00111
00112
00113 Vector<Object> resultVector = new Vector<Object>();
00114 resultVector.add(i, solution.get(key).toString());
00115 result.put(keyStr, resultVector);
00116
00117 }
00118
00119 Vector<Object> resultVector = result.get(keyStr);
00120 resultVector.add(i, solution.get(key).toString());
00121 }
00122 }
00123
00124 return result;
00125 }
00126
00131 public static void setObjectsOnTable(String[] objs) {
00132
00133 for (String identifier : objs)
00134 objectTypes.put(identifier, UNKNOWN_TYPE);
00135 }
00136
00141 public static void queryObjectTypes() {
00142
00143 for (Iterator<String> i = objectTypes.keySet().iterator(); i.hasNext();) {
00144 String instance = i.next();
00145 String type = inferObjectType(instance);
00146 if (objectTypes.get(instance).equals(UNKNOWN_TYPE)) {
00147 String localClassName = getLocalClassName(type);
00148 objectTypes.put(instance, localClassName);
00149 }
00150 }
00151
00152 }
00153
00163 public static String getLocalClassName(String uri) {
00164
00165 Map<String, Vector<Object>> answer = executeQuery(
00166 "rdf_split_url(Base, Local, '" + uri + "')", "");
00167
00168 for (Iterator<String> i = answer.keySet().iterator(); i.hasNext();) {
00169 String key = i.next();
00170 if (key.equals("Local") && answer.get(key).size() > 0)
00171 return ((String) answer.get(key).get(0)).replaceAll("'", "");
00172 }
00173
00174 throw new RuntimeException(
00175 "ERROR: Could not determine local class name.");
00176 }
00177
00185 public static String inferObjectType(String instanceName) {
00186
00187 Map<String, Vector<Object>> answer = executeQuery("rdf_has('"
00188 + instanceName + "', rdf:type, Type)", "");
00189
00190 for (Iterator<String> i = answer.keySet().iterator(); i.hasNext();) {
00191 String key = i.next();
00192 if (key.equals("Type") && answer.get(key).size() > 0)
00193 return ((String) answer.get(key).get(0)).replaceAll("'", "");
00194 }
00195
00196 throw new RuntimeException("ERROR: Cannot infer type of "
00197 + instanceName);
00198 }
00199
00207 public static String inferObjectClass(String instanceName) {
00208
00209 Map<String, Vector<Object>> answer = executeQuery("rdf_has('"
00210 + instanceName + "', rdf:type, Type)", "");
00211
00212 for (Iterator<String> i = answer.keySet().iterator(); i.hasNext();) {
00213 String key = i.next();
00214 if (key.equals("Type") && answer.get(key).size() > 0)
00215 return ((String) answer.get(key).get(0)).replaceAll("'", "");
00216 }
00217
00218 return null;
00219 }
00220
00221 public static void setPerception(String owlFile) {
00222 executeQuery("owl_parser:owl_parse('" + owlFile
00223 + "', false, false, true)", "");
00224 }
00225
00226 public static void setModelName(String modelName) {
00227 PrologInterface.modelName = modelName;
00228 }
00229
00230 public static void setModelPool(String modelPool) {
00231 PrologInterface.modelPool = modelPool;
00232 }
00233
00234 public static String[][] getMissingObjectsOnTable() {
00235
00236 try {
00237
00238 queryObjectTypes();
00239
00240 Server srldbServer = new Server(modelPool);
00241 Model model = srldbServer.getModel(modelName);
00242
00243
00244 Vector<String> evidence = new Vector<String>();
00245 evidence.add("takesPartIn(P,M)");
00246
00247
00248
00249
00250
00251 for (String instance : objectTypes.keySet()) {
00252 String objType = objectTypes.get(instance);
00253 String constantType = model.getConstantType(objType);
00254 String predicate = null;
00255 if (constantType != null) {
00256 if (constantType.equalsIgnoreCase("domUtensilT"))
00257 predicate = "usesAnyIn";
00258 else if (constantType.equalsIgnoreCase("objType_g"))
00259 predicate = "consumesAnyIn";
00260 if (predicate != null) {
00261 String evidenceAtom = String.format("%s(P,%s,M)",
00262 predicate, objType);
00263 evidence.add(evidenceAtom);
00264 } else
00265 System.err
00266 .println("Warning: Evidence on instance '"
00267 + instance
00268 + "' not considered because it is neither a utensil nor a consumable object known to the model.");
00269 } else
00270 System.err
00271 .println("Warning: Evidence on instance '"
00272 + instance
00273 + "' not considered because its type is not known to the model.");
00274 }
00275
00276
00277
00278 Vector<String> queries = new Vector<String>();
00279 queries.add("usesAnyIn");
00280 queries.add("consumesAnyIn");
00281
00282
00283 Vector<InferenceResult> results = srldbServer.query(modelName,
00284 queries, evidence);
00285
00286 String[][] result = new String[results.size()][2];
00287 int i = 0;
00288 for (InferenceResult res : results) {
00289 result[i][0] = res.params[1];
00290 result[i][1] = res.probability + "";
00291 System.out.println("object: " + result[i][0] + "; prob="
00292 + result[i][1]);
00293 i++;
00294 }
00295
00296 return result;
00297
00298 } catch (IOException e) {
00299 e.printStackTrace();
00300 } catch (ParseException e) {
00301 e.printStackTrace();
00302 } catch (Exception e) {
00303 e.printStackTrace();
00304 }
00305
00306 return null;
00307 }
00308
00317 public static String[] evidenceForPredciate(Model model, String moduleName, String predicate) {
00318
00319 Set<String> result = new HashSet<String>();
00320 String[] args = getArgsForPredicate(predicate, model.getName());
00321
00322 StringBuilder query = new StringBuilder();
00323 query.append(String.format("%s:%s(", moduleName, predicate));
00324
00325 for (int i = 0; i < args.length; i++) {
00326 query.append(String.format("Arg%d", i));
00327 if (i < args.length - 1)
00328 query.append(", ");
00329 }
00330 query.append(")");
00331
00332 System.out.println("Checking evidence for: "
00333 + query.toString());
00334
00335 Map<String, Vector<Object>> answer = executeQuery(query.toString(), "");
00336
00337 if (answer.get("Arg0") == null)
00338 return new String[0];
00339
00340 String[][] resultArray = new String[answer.get("Arg0").size()][args.length];
00341
00342 for (String arg : answer.keySet()) {
00343 int argIndex = Integer.valueOf(arg.substring(3));
00344
00345 Vector<Object> values = answer.get(arg);
00346 for (int i = 0; i < values.size(); i++) {
00347 resultArray[i][argIndex] = ((String) values.get(i)).replace(
00348 "'", "");
00349 }
00350 }
00351
00352 for (int sol = 0; sol < resultArray.length; sol++) {
00353
00354 StringBuilder evidence = new StringBuilder();
00355 evidence.append(String.format("%s(", predicate));
00356
00357 boolean discard = false;
00358 for (int arg = 0; arg < resultArray[sol].length; arg++) {
00359 String value = resultArray[sol][arg];
00360
00361 String clazzURI = inferObjectClass(value);
00362 if (clazzURI != null) {
00363 value = getLocalClassName(clazzURI);
00364
00365
00366
00367 if (model.constantMapToProbCog.get(value) == null) {
00368 discard = true;
00369 break;
00370 }
00371 }
00372 evidence.append(value);
00373 if (arg < args.length - 1)
00374 evidence.append(",");
00375 }
00376 evidence.append(")");
00377 if (!discard) {
00378 result.add(evidence.toString());
00379 System.out.println(" -> found evidence: "
00380 + evidence.toString());
00381 }
00382 }
00383
00384 return result.toArray(new String[0]);
00385 }
00386
00387 public static String[][][] performInference(String modelName, String moduleName, String[] query) {
00388 try {
00389
00390 reset();
00391
00392 Server srldbServer = new Server(modelPool);
00393 Model model = srldbServer.getModel(modelName);
00394
00395
00396 Vector<String> evidence = new Vector<String>();
00397
00398
00399 String[] predicates = getPredicatesForModel(modelName);
00400
00401 for (String pred : predicates) {
00402 String[] ev = evidenceForPredciate(model, moduleName, pred);
00403
00404 for (String e : ev)
00405 evidence.add(e);
00406 }
00407
00408
00409
00410 Vector<String> queries = new Vector<String>();
00411
00412
00413 HashMap<String, Vector<Integer>> queriesPredsParams = new HashMap<String, Vector<Integer>>();
00414
00415
00416 for (String q : query) {
00417
00418
00419
00420
00421 if(q.contains("(")&&q.contains("")) {
00422 String pred = q.split("\\(")[0];
00423 String[] params = q.split("\\(")[1].split("\\)")[0].split(",");
00424
00425 Vector<Integer> qParams = new Vector<Integer>();
00426 for(int k=0;k<params.length;k++) {
00427
00428 if(params[k].contains("?"))
00429 qParams.add(k);
00430 }
00431 queriesPredsParams.put(pred, qParams);
00432 queries.add(pred);
00433
00434 } else {
00435 queriesPredsParams.put(q, new Vector<Integer>());
00436 queries.add(q);
00437 }
00438
00439
00440 }
00441
00442
00443 Vector<InferenceResult> inferenceresults = srldbServer.query(modelName, queries, evidence);
00444 Vector<Vector<String[]>> resultvector = new Vector<Vector<String[]>>();
00445
00446 String lastQuery = "";
00447 Vector<String[]> r = null;
00448
00449
00450 for (InferenceResult ires : inferenceresults) {
00451
00452 if(ires.probability==0)
00453 continue;
00454
00455
00456 if(!ires.functionName.equals(lastQuery)) {
00457
00458 if(r!=null) {
00459 resultvector.add(r);
00460 }
00461 r = new Vector<String[]>();
00462 lastQuery=ires.functionName;
00463 }
00464
00465 if(queriesPredsParams.get(ires.functionName).size()==0) {
00466
00467
00468 r.add(new String[]{ires.toString().split(" ")[1], ires.toString().split(" ")[0]});
00469
00470
00471 } else {
00472
00473
00474 String params = "";
00475 for(int k : queriesPredsParams.get(ires.functionName)) {
00476
00477 params += ires.params[k];
00478 if(k<queriesPredsParams.get(ires.functionName).size()) {
00479 params+= "_";
00480 }
00481 }
00482 r.add(new String[]{params, ""+ires.probability});
00483 }
00484 }
00485 resultvector.add(r);
00486
00487
00488 String[][][] resultarray = new String[resultvector.size()][][];
00489 for(int i=0;i<resultvector.size();i++) {
00490
00491 resultarray[i] = new String[resultvector.get(i).size()][];
00492
00493 for(int j=0;j<resultvector.get(i).size();j++) {
00494 resultarray[i][j] = resultvector.get(i).get(j);
00495 }
00496 }
00497
00498 return resultarray;
00499
00500 } catch (IOException e) {
00501 e.printStackTrace();
00502 } catch (ParseException e) {
00503 e.printStackTrace();
00504 } catch (Exception e) {
00505 e.printStackTrace();
00506 }
00507
00508 return null;
00509 }
00510
00511 private static Server getServer() {
00512 if (server == null) {
00513 try {
00514 server = new Server(modelPool);
00515 } catch (Exception e) {
00516 throw new RuntimeException(e.getMessage());
00517 }
00518 }
00519
00520 return server;
00521 }
00522
00523 public static String[] getPredicatesForModel(String modelName) {
00524 Server s = getServer();
00525
00526 Vector<String[]> predicates = s.getPredicates(modelName);
00527
00528 String[] result = new String[predicates.size()];
00529 for (int i = 0; i < result.length; i++)
00530 result[i] = predicates.get(i)[0];
00531
00532 return result;
00533 }
00534
00535 public static String[] getArgsForPredicate(String predicate,
00536 String modelName) {
00537 Server s = getServer();
00538
00539 String[] predicates = getPredicatesForModel(modelName);
00540 Vector<String[]> args = s.getPredicates(modelName);
00541
00542 for (int i = 0; i < predicates.length; i++) {
00543
00544 if (predicates[i].equals(predicate) && args.elementAt(i).length > 1) {
00545 String[] arguments = new String[args.elementAt(i).length - 1];
00546
00547 for (int j = 0; j < arguments.length; j++)
00548 arguments[j] = args.elementAt(i)[j + 1];
00549
00550 return arguments;
00551 }
00552
00553 }
00554
00555 return new String[0];
00556 }
00557
00558 public static void main(String[] args) {
00559 performInference("tableSetting_fall09", "mod_probcog_tablesetting", new String[]{"usesAnyIn(person1, ?, meal1)", "sitsAtIn(person1, ?, bla)"});
00560 }
00561 }