EMModel.java
Go to the documentation of this file.
00001 package edu.wpi.rail.jinteractiveworld.model;
00002 
00003 import edu.wpi.rail.jinteractiveworld.data.*;
00004 import edu.wpi.rail.jinteractiveworld.ros.msgs.interactiveworldmsgs.*;
00005 import edu.wpi.rail.jrosbridge.messages.geometry.Point;
00006 import org.ejml.simple.SimpleMatrix;
00007 import weka.clusterers.EM;
00008 import weka.core.*;
00009 import weka.core.Instances;
00010 import weka.filters.Filter;
00011 import weka.filters.unsupervised.attribute.Remove;
00012 import java.util.ArrayList;
00013 import java.util.List;
00014 
00024 public class EMModel implements Model {
00025 
00026         public enum RankingFunction {
00027                 CUSTOM, SCATTER_SEPARABILITY
00028         }
00029 
00030         private EM em;
00031         private Placement best;
00032         private double decisionValue, sigmaX, sigmaY, sigmaZ, sigmaTheta;
00033         private DataSet data;
00034         private RankingFunction rankingType;
00035 
00043         public EMModel(DataSet data, RankingFunction rankingType) {
00044                 this.data = data;
00045                 this.rankingType = rankingType;
00046                 this.decisionValue = Double.POSITIVE_INFINITY;
00047                 this.train();
00048         }
00049 
00056         @Override
00057         public void add(DataPoint point) {
00058                 this.data.add(point);
00059                 // retrain
00060                 this.train();
00061         }
00062 
00068         public int size() {
00069                 return this.data.size();
00070         }
00071 
00077         @Override
00078         public List<DataPoint> getData() {
00079                 // copy into an array list
00080                 ArrayList<DataPoint> list = new ArrayList<DataPoint>();
00081                 for (int i = 0; i < this.data.size(); i++) {
00082                         list.add(this.data.get(i));
00083                 }
00084                 return list;
00085         }
00086 
00092         @Override
00093         public String getReferenceFrame() {
00094                 return this.data.getReferenceFrame();
00095         }
00096 
00102         @Override
00103         public Item getItem() {
00104                 return this.data.getItem();
00105         }
00106 
00112         @Override
00113         public Room getRoom() {
00114                 return this.data.getRoom();
00115         }
00116 
00122         @Override
00123         public Surface getSurface() {
00124                 return this.data.getSurface();
00125         }
00126 
00130         @Override
00131         public void train() {
00132                 try {
00133                         this.em = new EM();
00134                         this.best = null;
00135                         this.decisionValue = Double.POSITIVE_INFINITY;
00136                         this.sigmaX = 0;
00137                         this.sigmaY = 0;
00138                         this.sigmaZ = 0;
00139                         this.sigmaTheta = 0;
00140 
00141                         // get the instances
00142                         Instances instances = this.data.toInstances();
00143 
00144                         // remove z for now
00145                         Remove rm = new Remove();
00146                         rm.setAttributeIndicesArray(new int[] { DataSet.Z_ATTRIBUTE.index() });
00147                         rm.setInputFormat(instances);
00148                         Instances newInstances = Filter.useFilter(instances, rm);
00149 
00150                         // run EM
00151                         this.em.buildClusterer(newInstances);
00152                         // get the results
00153                         double clusterData[][][] = this.em.getClusterModelsNumericAtts();
00154                         double[] priors = this.em.clusterPriors();
00155 
00156 
00157                         if (this.rankingType.equals(RankingFunction.CUSTOM)) {
00158 
00159                                 // cluster each point
00160                                 @SuppressWarnings("unchecked")
00161                                 ArrayList<Instance>[] clusters = (ArrayList<Instance>[]) new ArrayList[clusterData.length];
00162                                 for (int i = 0; i < clusters.length; i++) {
00163                                         clusters[i] = new ArrayList<Instance>();
00164                                 }
00165                                 for (int i = 0; i < newInstances.numInstances(); i++) {
00166                                         Instance inst = newInstances.instance(i);
00167                                         int clust = this.em.clusterInstance(inst);
00168                                         clusters[clust].add(inst);
00169                                 }
00170 
00171                                 // rank the clusters
00172                                 for (int m = 0; m < clusters.length; m++) {
00173                                         double distance = 0;
00174                                         ArrayList<Instance> curInstances = clusters[m];
00175                                         for (int i = 0; i < curInstances.size(); i++) {
00176                                                 Instance instI = curInstances.get(i);
00177                                                 for (int j = 0; j < curInstances.size(); j++) {
00178                                                         if (i != j) {
00179                                                                 Instance instJ = curInstances.get(j);
00180                                                                 // get each attribute
00181                                                                 double sum = 0;
00182                                                                 for (int k = 0; k < instI.numAttributes(); k++) {
00183                                                                         sum += Math.pow(
00184                                                                                         instI.value(k) - instJ.value(k), 2.0);
00185                                                                 }
00186                                                                 // get the distance
00187                                                                 distance += Math.sqrt(sum);
00188                                                         }
00189                                                 }
00190                                         }
00191                                         // average
00192                                         double ranking = distance
00193                                                         / (curInstances.size() * (curInstances.size() - 1)) * 1.0
00194                                                         / (((double) curInstances.size()) / ((double) this.size()));
00195 
00196                                         // check for a new best
00197                                         if (ranking < this.decisionValue) {
00198                                                 this.decisionValue = ranking;
00199                                                 double x = clusterData[m][0][0];
00200                                                 this.sigmaX = clusterData[m][0][1];
00201                                                 double y = clusterData[m][1][0];
00202                                                 this.sigmaY = clusterData[m][1][1];
00203                                                 double z = 0.0;
00204                                                 this.sigmaZ = 0.0;
00205                                                 double theta = clusterData[m][2][0];
00206                                                 this.sigmaTheta = clusterData[m][2][1];
00207                                                 this.best = new Placement(this.getItem(), this.getRoom(), this.getSurface(), this.getReferenceFrame(), new Point(x, y, z), theta);
00208                                         }
00209                                 }
00210                         } else if (this.rankingType.equals(RankingFunction.SCATTER_SEPARABILITY)) {
00211                                 int n = newInstances.numAttributes();
00212                                 // priors
00213 //                              SimpleMatrix sw = new SimpleMatrix(n, n);
00214 //                              SimpleMatrix mo = new SimpleMatrix(n, 1);
00215 //                              SimpleMatrix sb = new SimpleMatrix(n, n);
00216 
00217                                 // go through each cluster for sw and mo
00218                                 for (int i = 0; i < clusterData.length; i++) {
00219                                         // covariance matrix
00220                                         double[][] covArray = new double[n][n];
00221                                         // mu for the cluster
00222 //                                      double [][] muArray = new double[n][1];
00223                                         for (int j = 0; j < n; j++) {
00224                                                 covArray[j][j] = clusterData[i][j][1];
00225 //                                              muArray[j][0] = clusterData[i][j][0];
00226                                         }
00227                                         SimpleMatrix cov = new SimpleMatrix(covArray);
00228                                         SimpleMatrix covScale = cov.scale(priors[i]);
00229                                         double trace = covScale.trace();
00230                                         if (trace < this.decisionValue) {
00231                                                 this.decisionValue = trace;
00232                                                 double x = clusterData[i][0][0];
00233                                                 this.sigmaX = clusterData[i][0][1];
00234                                                 double y = clusterData[i][1][0];
00235                                                 this.sigmaY = clusterData[i][1][1];
00236                                                 double z = 0.0;
00237                                                 this.sigmaZ = 0.0;
00238                                                 double theta = clusterData[i][2][0];
00239                                                 this.sigmaTheta = clusterData[i][2][1];
00240                                                 this.best = new Placement(this.getItem(), this.getRoom(), this.getSurface(), this.getReferenceFrame(), new Point(x, y, z), theta);
00241                                         }
00242 //                                      SimpleMatrix mu = new SimpleMatrix(muArray);
00243 //                                      SimpleMatrix muScale = mu.scale(priors[i]);
00244 //                                      mo = mo.plus(muScale);
00245                                 }
00246 
00247 //                              // go through each cluster for sb
00248 //                              for (int i = 0; i < clusterData.length; i++) {
00249 //                                      // mu for the cluster
00250 //                                      double [][] muArray = new double[n][1];
00251 //                                      for (int j = 0; j < n; j++) {
00252 //                                              muArray[j][0] = clusterData[i][j][0];
00253 //                                      }
00254 //                                      SimpleMatrix mu = new SimpleMatrix(muArray);
00255 //                                      SimpleMatrix meanDiff = mu.minus(mo);
00256 //                                      SimpleMatrix meanDiffTrans = meanDiff.transpose();
00257 //                                      SimpleMatrix product = meanDiff.mult(meanDiffTrans);
00258 //                                      SimpleMatrix scale = product.scale(priors[i]);
00259 //                                      sb = sb.plus(scale);
00260 //                              }
00261 
00262                                 // trace as the value
00263 //                              this.decisionValue = sw.invert().mult(sb).trace();
00264                                 //this.decisionValue = sw.trace();
00265                                 System.out.println(this.decisionValue); // + " -- " + sw.invert().trace());
00266                         }
00267 
00268                 } catch (Exception e) {
00269                         System.err.println("[ERROR]: Could not train model: "
00270                                         + e.getMessage());
00271                         e.printStackTrace();
00272                 }
00273         }
00274 
00282         @Override
00283         public Placement getPlacementLocation() {
00284                 return this.best;
00285         }
00286 
00293         @Override
00294         public double getDecisionValue() {
00295                 return this.decisionValue;
00296         }
00297 
00303         @Override
00304         public double getSigmaX() {
00305                 return this.sigmaX;
00306         }
00307 
00313         @Override
00314         public double getSigmaY() {
00315                 return this.sigmaY;
00316         }
00317 
00323         @Override
00324         public double getSigmaZ() {
00325                 return this.sigmaZ;
00326         }
00327 
00333         @Override
00334         public double getSigmaTheta() {
00335                 return this.sigmaTheta;
00336         }
00337 }


jinteractiveworld
Author(s): Russell Toris
autogenerated on Thu Jun 6 2019 21:34:23