EMModel.java
Go to the documentation of this file.
00001 package edu.wpi.rail.jinteractiveworld.model;
00002 
00003 import edu.wpi.rail.jinteractiveworld.data.*;
00004 import edu.wpi.rail.jinteractiveworld.ros.msgs.interactiveworldmsgs.*;
00005 import edu.wpi.rail.jrosbridge.messages.geometry.Point;
00006 import weka.clusterers.EM;
00007 import weka.core.*;
00008 import weka.core.Instances;
00009 import weka.filters.Filter;
00010 import weka.filters.unsupervised.attribute.Remove;
00011 import java.util.ArrayList;
00012 import java.util.List;
00013 
00023 public class EMModel implements Model {
00024 
00025         private EM em;
00026         private Placement best;
00027         private double decisionValue, sigmaX, sigmaY, sigmaZ, sigmaTheta;
00028         private DataSet data;
00029 
00037         public EMModel(DataSet data) {
00038                 this.data = data;
00039                 this.train();
00040         }
00041 
00048         @Override
00049         public void add(DataPoint point) {
00050                 this.data.add(point);
00051                 // retrain
00052                 this.train();
00053         }
00054 
00060         public int size() {
00061                 return this.data.size();
00062         }
00063 
00069         @Override
00070         public List<DataPoint> getData() {
00071                 // copy into an array list
00072                 ArrayList<DataPoint> list = new ArrayList<DataPoint>();
00073                 for (int i = 0; i < this.data.size(); i++) {
00074                         list.add(this.data.get(i));
00075                 }
00076                 return list;
00077         }
00078 
00084         @Override
00085         public String getReferenceFrame() {
00086                 return this.data.getReferenceFrame();
00087         }
00088 
00094         @Override
00095         public Item getItem() {
00096                 return this.data.getItem();
00097         }
00098 
00104         @Override
00105         public Room getRoom() {
00106                 return this.data.getRoom();
00107         }
00108 
00114         @Override
00115         public Surface getSurface() {
00116                 return this.data.getSurface();
00117         }
00118 
00122         @Override
00123         public void train() {
00124                 try {
00125                         this.em = new EM();
00126                         this.best = null;
00127                         this.decisionValue = Double.POSITIVE_INFINITY;
00128                         this.sigmaX = 0;
00129                         this.sigmaY = 0;
00130                         this.sigmaZ = 0;
00131                         this.sigmaTheta = 0;
00132 
00133                         // get the instances
00134                         Instances instances = this.data.toInstances();
00135 
00136                         // remove z for now
00137                         Remove rm = new Remove();
00138                         rm.setAttributeIndicesArray(new int[] { DataSet.Z_ATTRIBUTE.index() });
00139                         rm.setInputFormat(instances);
00140                         Instances newInstances = Filter.useFilter(instances, rm);
00141 
00142                         // run EM
00143                         this.em.buildClusterer(newInstances);
00144 
00145                         // get the results
00146                         double clusterData[][][] = this.em.getClusterModelsNumericAtts();
00147 
00148                         // cluster each point
00149                         @SuppressWarnings("unchecked")
00150                         ArrayList<Instance>[] clusters = (ArrayList<Instance>[]) new ArrayList[clusterData.length];
00151                         for (int i = 0; i < clusters.length; i++) {
00152                                 clusters[i] = new ArrayList<Instance>();
00153                         }
00154                         for (int i = 0; i < newInstances.numInstances(); i++) {
00155                                 Instance inst = newInstances.instance(i);
00156                                 int clust = this.em.clusterInstance(inst);
00157                                 clusters[clust].add(inst);
00158                         }
00159 
00160                         // determine the densest cluster
00161                         for (int m = 0; m < clusters.length; m++) {
00162                                 ArrayList<Instance> curInsts = clusters[m];
00163                                 double distance = 0;
00164                                 for (int i = 0; i < curInsts.size(); i++) {
00165                                         Instance instI = curInsts.get(i);
00166                                         for (int j = 0; j < curInsts.size(); j++) {
00167                                                 if (i != j) {
00168                                                         Instance instJ = curInsts.get(j);
00169                                                         // get each attribute
00170                                                         double sum = 0;
00171                                                         for (int k = 0; k < instI.numAttributes(); k++) {
00172                                                                 sum += Math.pow(
00173                                                                                 instI.value(k) - instJ.value(k), 2.0);
00174                                                         }
00175                                                         // get the distance
00176                                                         distance += Math.sqrt(sum);
00177                                                 }
00178                                         }
00179                                 }
00180                                 // average
00181                                 double density = distance
00182                                                 / (curInsts.size() * (curInsts.size() - 1)) * 1.0
00183                                                 / (((double) curInsts.size()) / ((double) this.size()));
00184 
00185                                 // check for a new best
00186                                 if (density < this.decisionValue) {
00187                                         this.decisionValue = density;
00188                                         double x = clusterData[m][0][0];
00189                                         this.sigmaX = clusterData[m][0][1];
00190                                         double y = clusterData[m][1][0];
00191                                         this.sigmaY = clusterData[m][1][1];
00192                                         double z = 0.0;
00193                                         this.sigmaZ = 0.0;
00194                                         double theta = clusterData[m][2][0];
00195                                         this.sigmaTheta = clusterData[m][2][1];
00196                                         this.best = new Placement(this.getItem(), this.getRoom(), this.getSurface(), this.getReferenceFrame(), new Point(x, y, z), theta);
00197                                 }
00198                         }
00199                 } catch (Exception e) {
00200                         System.err.println("[ERROR]: Could not train model: "
00201                                         + e.getMessage());
00202                 }
00203         }
00204 
00212         @Override
00213         public Placement getPlacementLocation() {
00214                 return this.best;
00215         }
00216 
00223         @Override
00224         public double getDecisionValue() {
00225                 return this.decisionValue;
00226         }
00227 
00233         @Override
00234         public double getSigmaX() {
00235                 return this.sigmaX;
00236         }
00237 
00243         @Override
00244         public double getSigmaY() {
00245                 return this.sigmaY;
00246         }
00247 
00253         @Override
00254         public double getSigmaZ() {
00255                 return this.sigmaZ;
00256         }
00257 
00263         @Override
00264         public double getSigmaTheta() {
00265                 return this.sigmaTheta;
00266         }
00267 }


jinteractiveworld
Author(s): Russell Toris
autogenerated on Sun Dec 14 2014 11:27:03