Go to the documentation of this file.00001 package edu.wpi.rail.jinteractiveworld.model;
00002
00003 import edu.wpi.rail.jinteractiveworld.data.*;
00004 import edu.wpi.rail.jinteractiveworld.ros.msgs.interactiveworldmsgs.*;
00005 import edu.wpi.rail.jrosbridge.messages.geometry.Point;
00006 import weka.clusterers.EM;
00007 import weka.core.*;
00008 import weka.core.Instances;
00009 import weka.filters.Filter;
00010 import weka.filters.unsupervised.attribute.Remove;
00011 import java.util.ArrayList;
00012 import java.util.List;
00013
00023 public class EMModel implements Model {
00024
00025 private EM em;
00026 private Placement best;
00027 private double decisionValue, sigmaX, sigmaY, sigmaZ, sigmaTheta;
00028 private DataSet data;
00029
00037 public EMModel(DataSet data) {
00038 this.data = data;
00039 this.train();
00040 }
00041
00048 @Override
00049 public void add(DataPoint point) {
00050 this.data.add(point);
00051
00052 this.train();
00053 }
00054
00060 public int size() {
00061 return this.data.size();
00062 }
00063
00069 @Override
00070 public List<DataPoint> getData() {
00071
00072 ArrayList<DataPoint> list = new ArrayList<DataPoint>();
00073 for (int i = 0; i < this.data.size(); i++) {
00074 list.add(this.data.get(i));
00075 }
00076 return list;
00077 }
00078
00084 @Override
00085 public String getReferenceFrame() {
00086 return this.data.getReferenceFrame();
00087 }
00088
00094 @Override
00095 public Item getItem() {
00096 return this.data.getItem();
00097 }
00098
00104 @Override
00105 public Room getRoom() {
00106 return this.data.getRoom();
00107 }
00108
00114 @Override
00115 public Surface getSurface() {
00116 return this.data.getSurface();
00117 }
00118
00122 @Override
00123 public void train() {
00124 try {
00125 this.em = new EM();
00126 this.best = null;
00127 this.decisionValue = Double.POSITIVE_INFINITY;
00128 this.sigmaX = 0;
00129 this.sigmaY = 0;
00130 this.sigmaZ = 0;
00131 this.sigmaTheta = 0;
00132
00133
00134 Instances instances = this.data.toInstances();
00135
00136
00137 Remove rm = new Remove();
00138 rm.setAttributeIndicesArray(new int[] { DataSet.Z_ATTRIBUTE.index() });
00139 rm.setInputFormat(instances);
00140 Instances newInstances = Filter.useFilter(instances, rm);
00141
00142
00143 this.em.buildClusterer(newInstances);
00144
00145
00146 double clusterData[][][] = this.em.getClusterModelsNumericAtts();
00147
00148
00149 @SuppressWarnings("unchecked")
00150 ArrayList<Instance>[] clusters = (ArrayList<Instance>[]) new ArrayList[clusterData.length];
00151 for (int i = 0; i < clusters.length; i++) {
00152 clusters[i] = new ArrayList<Instance>();
00153 }
00154 for (int i = 0; i < newInstances.numInstances(); i++) {
00155 Instance inst = newInstances.instance(i);
00156 int clust = this.em.clusterInstance(inst);
00157 clusters[clust].add(inst);
00158 }
00159
00160
00161 for (int m = 0; m < clusters.length; m++) {
00162 ArrayList<Instance> curInsts = clusters[m];
00163 double distance = 0;
00164 for (int i = 0; i < curInsts.size(); i++) {
00165 Instance instI = curInsts.get(i);
00166 for (int j = 0; j < curInsts.size(); j++) {
00167 if (i != j) {
00168 Instance instJ = curInsts.get(j);
00169
00170 double sum = 0;
00171 for (int k = 0; k < instI.numAttributes(); k++) {
00172 sum += Math.pow(
00173 instI.value(k) - instJ.value(k), 2.0);
00174 }
00175
00176 distance += Math.sqrt(sum);
00177 }
00178 }
00179 }
00180
00181 double density = distance
00182 / (curInsts.size() * (curInsts.size() - 1)) * 1.0
00183 / (((double) curInsts.size()) / ((double) this.size()));
00184
00185
00186 if (density < this.decisionValue) {
00187 this.decisionValue = density;
00188 double x = clusterData[m][0][0];
00189 this.sigmaX = clusterData[m][0][1];
00190 double y = clusterData[m][1][0];
00191 this.sigmaY = clusterData[m][1][1];
00192 double z = 0.0;
00193 this.sigmaZ = 0.0;
00194 double theta = clusterData[m][2][0];
00195 this.sigmaTheta = clusterData[m][2][1];
00196 this.best = new Placement(this.getItem(), this.getRoom(), this.getSurface(), this.getReferenceFrame(), new Point(x, y, z), theta);
00197 }
00198 }
00199 } catch (Exception e) {
00200 System.err.println("[ERROR]: Could not train model: "
00201 + e.getMessage());
00202 }
00203 }
00204
00212 @Override
00213 public Placement getPlacementLocation() {
00214 return this.best;
00215 }
00216
00223 @Override
00224 public double getDecisionValue() {
00225 return this.decisionValue;
00226 }
00227
00233 @Override
00234 public double getSigmaX() {
00235 return this.sigmaX;
00236 }
00237
00243 @Override
00244 public double getSigmaY() {
00245 return this.sigmaY;
00246 }
00247
00253 @Override
00254 public double getSigmaZ() {
00255 return this.sigmaZ;
00256 }
00257
00263 @Override
00264 public double getSigmaTheta() {
00265 return this.sigmaTheta;
00266 }
00267 }