Go to the documentation of this file.00001 package edu.wpi.rail.jinteractiveworld.model;
00002
00003 import edu.wpi.rail.jinteractiveworld.data.*;
00004 import edu.wpi.rail.jinteractiveworld.ros.msgs.interactiveworldmsgs.*;
00005 import edu.wpi.rail.jrosbridge.messages.geometry.Point;
00006 import org.ejml.simple.SimpleMatrix;
00007 import weka.clusterers.EM;
00008 import weka.core.*;
00009 import weka.core.Instances;
00010 import weka.filters.Filter;
00011 import weka.filters.unsupervised.attribute.Remove;
00012 import java.util.ArrayList;
00013 import java.util.List;
00014
00024 public class EMModel implements Model {
00025
00026 public enum RankingFunction {
00027 CUSTOM, SCATTER_SEPARABILITY
00028 }
00029
00030 private EM em;
00031 private Placement best;
00032 private double decisionValue, sigmaX, sigmaY, sigmaZ, sigmaTheta;
00033 private DataSet data;
00034 private RankingFunction rankingType;
00035
00043 public EMModel(DataSet data, RankingFunction rankingType) {
00044 this.data = data;
00045 this.rankingType = rankingType;
00046 this.decisionValue = Double.POSITIVE_INFINITY;
00047 this.train();
00048 }
00049
00056 @Override
00057 public void add(DataPoint point) {
00058 this.data.add(point);
00059
00060 this.train();
00061 }
00062
00068 public int size() {
00069 return this.data.size();
00070 }
00071
00077 @Override
00078 public List<DataPoint> getData() {
00079
00080 ArrayList<DataPoint> list = new ArrayList<DataPoint>();
00081 for (int i = 0; i < this.data.size(); i++) {
00082 list.add(this.data.get(i));
00083 }
00084 return list;
00085 }
00086
00092 @Override
00093 public String getReferenceFrame() {
00094 return this.data.getReferenceFrame();
00095 }
00096
00102 @Override
00103 public Item getItem() {
00104 return this.data.getItem();
00105 }
00106
00112 @Override
00113 public Room getRoom() {
00114 return this.data.getRoom();
00115 }
00116
00122 @Override
00123 public Surface getSurface() {
00124 return this.data.getSurface();
00125 }
00126
00130 @Override
00131 public void train() {
00132 try {
00133 this.em = new EM();
00134 this.best = null;
00135 this.decisionValue = Double.POSITIVE_INFINITY;
00136 this.sigmaX = 0;
00137 this.sigmaY = 0;
00138 this.sigmaZ = 0;
00139 this.sigmaTheta = 0;
00140
00141
00142 Instances instances = this.data.toInstances();
00143
00144
00145 Remove rm = new Remove();
00146 rm.setAttributeIndicesArray(new int[] { DataSet.Z_ATTRIBUTE.index() });
00147 rm.setInputFormat(instances);
00148 Instances newInstances = Filter.useFilter(instances, rm);
00149
00150
00151 this.em.buildClusterer(newInstances);
00152
00153 double clusterData[][][] = this.em.getClusterModelsNumericAtts();
00154 double[] priors = this.em.clusterPriors();
00155
00156
00157 if (this.rankingType.equals(RankingFunction.CUSTOM)) {
00158
00159
00160 @SuppressWarnings("unchecked")
00161 ArrayList<Instance>[] clusters = (ArrayList<Instance>[]) new ArrayList[clusterData.length];
00162 for (int i = 0; i < clusters.length; i++) {
00163 clusters[i] = new ArrayList<Instance>();
00164 }
00165 for (int i = 0; i < newInstances.numInstances(); i++) {
00166 Instance inst = newInstances.instance(i);
00167 int clust = this.em.clusterInstance(inst);
00168 clusters[clust].add(inst);
00169 }
00170
00171
00172 for (int m = 0; m < clusters.length; m++) {
00173 double distance = 0;
00174 ArrayList<Instance> curInstances = clusters[m];
00175 for (int i = 0; i < curInstances.size(); i++) {
00176 Instance instI = curInstances.get(i);
00177 for (int j = 0; j < curInstances.size(); j++) {
00178 if (i != j) {
00179 Instance instJ = curInstances.get(j);
00180
00181 double sum = 0;
00182 for (int k = 0; k < instI.numAttributes(); k++) {
00183 sum += Math.pow(
00184 instI.value(k) - instJ.value(k), 2.0);
00185 }
00186
00187 distance += Math.sqrt(sum);
00188 }
00189 }
00190 }
00191
00192 double ranking = distance
00193 / (curInstances.size() * (curInstances.size() - 1)) * 1.0
00194 / (((double) curInstances.size()) / ((double) this.size()));
00195
00196
00197 if (ranking < this.decisionValue) {
00198 this.decisionValue = ranking;
00199 double x = clusterData[m][0][0];
00200 this.sigmaX = clusterData[m][0][1];
00201 double y = clusterData[m][1][0];
00202 this.sigmaY = clusterData[m][1][1];
00203 double z = 0.0;
00204 this.sigmaZ = 0.0;
00205 double theta = clusterData[m][2][0];
00206 this.sigmaTheta = clusterData[m][2][1];
00207 this.best = new Placement(this.getItem(), this.getRoom(), this.getSurface(), this.getReferenceFrame(), new Point(x, y, z), theta);
00208 }
00209 }
00210 } else if (this.rankingType.equals(RankingFunction.SCATTER_SEPARABILITY)) {
00211 int n = newInstances.numAttributes();
00212
00213
00214
00215
00216
00217
00218 for (int i = 0; i < clusterData.length; i++) {
00219
00220 double[][] covArray = new double[n][n];
00221
00222
00223 for (int j = 0; j < n; j++) {
00224 covArray[j][j] = clusterData[i][j][1];
00225
00226 }
00227 SimpleMatrix cov = new SimpleMatrix(covArray);
00228 SimpleMatrix covScale = cov.scale(priors[i]);
00229 double trace = covScale.trace();
00230 if (trace < this.decisionValue) {
00231 this.decisionValue = trace;
00232 double x = clusterData[i][0][0];
00233 this.sigmaX = clusterData[i][0][1];
00234 double y = clusterData[i][1][0];
00235 this.sigmaY = clusterData[i][1][1];
00236 double z = 0.0;
00237 this.sigmaZ = 0.0;
00238 double theta = clusterData[i][2][0];
00239 this.sigmaTheta = clusterData[i][2][1];
00240 this.best = new Placement(this.getItem(), this.getRoom(), this.getSurface(), this.getReferenceFrame(), new Point(x, y, z), theta);
00241 }
00242
00243
00244
00245 }
00246
00247
00248
00249
00250
00251
00252
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264
00265 System.out.println(this.decisionValue);
00266 }
00267
00268 } catch (Exception e) {
00269 System.err.println("[ERROR]: Could not train model: "
00270 + e.getMessage());
00271 e.printStackTrace();
00272 }
00273 }
00274
00282 @Override
00283 public Placement getPlacementLocation() {
00284 return this.best;
00285 }
00286
00293 @Override
00294 public double getDecisionValue() {
00295 return this.decisionValue;
00296 }
00297
00303 @Override
00304 public double getSigmaX() {
00305 return this.sigmaX;
00306 }
00307
00313 @Override
00314 public double getSigmaY() {
00315 return this.sigmaY;
00316 }
00317
00323 @Override
00324 public double getSigmaZ() {
00325 return this.sigmaZ;
00326 }
00327
00333 @Override
00334 public double getSigmaTheta() {
00335 return this.sigmaTheta;
00336 }
00337 }