00001
00002
00003
00004 package edu.tum.cs.bayesnets.core;
00005
00006 import java.util.Arrays;
00007
/**
 * Maps continuous values onto a finite set of discrete output values (bins).
 *
 * <p>The nested {@link Default} class is an implementation based on a sorted list of
 * split points.
 */
public interface DiscretizationFilter {

    /**
     * Returns the discrete output values, one per bin.
     *
     * @return the output values, indexed by bin
     */
    public String[] getOutputValues();

    /**
     * Returns the output value of the bin the given continuous value belongs to.
     *
     * @param continuous the continuous value to discretize
     * @return the output value of the matching bin
     */
    public String getValueForContinuous(double continuous);

    /**
     * Returns a continuous value that is representative of the given bin.
     *
     * @param bin index of the bin
     * @return a continuous value lying inside the bin
     */
    public double getExampleValue(int bin);

    /**
     * Returns the bounds of the given bin.
     *
     * @param bin index of the bin
     * @return a two-element array {@code {lowerBound, upperBound}}
     */
    public double[] getIntervals(int bin);

    /**
     * Makes the filter aware of an output value given in its textual form.
     *
     * @param outputValue textual representation of the output value
     */
    public void addOutputValue(String outputValue);

    /**
     * Split-point based filter: {@code n} split points define {@code n + 1} bins.
     *
     * <p>Bin {@code 0} covers {@code (-inf, s0]}, bin {@code i} covers
     * {@code (s(i-1), s(i)]} and bin {@code n} covers {@code (s(n-1), +inf)}.
     * The labels of the bins are generated with two significant decimals
     * ({@code %.2e}), so split points that agree in that precision yield identical labels.
     */
    public class Default implements DiscretizationFilter {

        /** Label used for the single bin of a filter without split points. */
        private static final String UNBOUNDED_LABEL = "-inf < && <= inf";

        /** The split points in ascending order. */
        protected double[] splitPoints;

        /** The labels of the bins; always has {@code splitPoints.length + 1} entries. */
        protected String[] outputValues;

        /**
         * Creates a filter for the given split points.
         *
         * @param splitPoints the split points in any order; the array is copied, not retained
         */
        public Default(double[] splitPoints) {
            init(splitPoints);
        }

        /**
         * (Re-)initializes the filter: stores a sorted copy of the split points and
         * regenerates the labels of all bins.
         *
         * @param splitPoints the split points in any order; the array is copied, not retained
         */
        protected void init(double[] splitPoints) {
            this.splitPoints = splitPoints.clone();
            Arrays.sort(this.splitPoints);

            int n = this.splitPoints.length;
            outputValues = new String[n + 1];
            if (n == 0) {
                outputValues[0] = UNBOUNDED_LABEL;
                return;
            }

            // The labels are parsed back by addOutputValue via Double.parseDouble, which only
            // understands '.' as decimal separator. Formatting must therefore not depend on
            // the default locale (which may use ',').
            java.util.Locale locale = java.util.Locale.ROOT;
            outputValues[0] = String.format(locale, "<= %.2e", this.splitPoints[0]);
            for (int i = 0; i < n - 1; i++) {
                outputValues[i + 1] = String.format(locale, "%.2e < && <= %.2e",
                        this.splitPoints[i], this.splitPoints[i + 1]);
            }
            outputValues[n] = String.format(locale, "> %.2e", this.splitPoints[n - 1]);
        }

        /**
         * Returns a value inside the given bin: the midpoint for inner bins, and a value
         * 10% (of the split point's magnitude) beyond the outermost split point for the
         * first and last bin. Without split points the result is {@code 0.0}.
         *
         * @param bin index of the bin, {@code 0 <= bin <= splitPoints.length}
         * @return a continuous value lying inside the bin
         * @throws IllegalArgumentException if {@code bin} is out of range
         */
        @Override
        public double getExampleValue(int bin) {
            checkBin(bin);
            int n = splitPoints.length;
            if (n == 0) {
                return 0.0;
            }
            if (bin == 0) {
                double first = splitPoints[0];
                // Must end up at or below the first split point; scaling by 0.9 only does
                // so for non-negative values, negative ones have to be scaled by 1.1.
                return first >= 0 ? 0.9 * first : 1.1 * first;
            }
            if (bin == n) {
                double last = splitPoints[n - 1];
                // Must end up strictly above the last split point.
                if (last > 0) {
                    return 1.1 * last;
                }
                if (last < 0) {
                    return 0.9 * last;
                }
                return 1.0; // last == 0: scaling would stay on the boundary
            }
            return 0.5 * (splitPoints[bin] + splitPoints[bin - 1]);
        }

        /**
         * Returns the bounds of the given bin; the outermost bins are bounded by
         * negative/positive infinity.
         *
         * @param bin index of the bin, {@code 0 <= bin <= splitPoints.length}
         * @return a new two-element array {@code {lowerBound, upperBound}}
         * @throws IllegalArgumentException if {@code bin} is out of range
         */
        @Override
        public double[] getIntervals(int bin) {
            checkBin(bin);
            int n = splitPoints.length;
            double lower = (bin == 0) ? Double.NEGATIVE_INFINITY : splitPoints[bin - 1];
            double upper = (bin == n) ? Double.POSITIVE_INFINITY : splitPoints[bin];
            return new double[] {lower, upper};
        }

        /**
         * Adds a split point and regenerates the bins. No check for duplicates is made;
         * use {@link #containsSplitPoint(double)} first if duplicates are not wanted.
         *
         * @param point the split point to add
         */
        public void addSplitPoint(double point) {
            double[] newSplitPoints = new double[splitPoints.length + 1];
            System.arraycopy(splitPoints, 0, newSplitPoints, 0, splitPoints.length);
            newSplitPoints[splitPoints.length] = point;
            init(newSplitPoints); // sorts and rebuilds the labels
        }

        /**
         * Tells whether the given value is one of the split points.
         *
         * @param splitPoint the value to look for
         * @return {@code true} if the value is a split point of this filter
         */
        public boolean containsSplitPoint(double splitPoint) {
            // splitPoints is kept sorted by init, so a binary search is valid; a
            // non-negative result means the key was found.
            return Arrays.binarySearch(splitPoints, splitPoint) >= 0;
        }

        /**
         * Adds the split points mentioned in a bin label to this filter. Understood forms
         * (surrounding whitespace is ignored):
         * <ul>
         *   <li>{@code "-inf < && <= inf"} &ndash; removes all split points</li>
         *   <li>{@code "<= x"} and {@code "> x"} &ndash; adds {@code x}</li>
         *   <li>{@code "x < && <= y"} &ndash; adds {@code x} and {@code y}</li>
         * </ul>
         * Split points that are already present are not added again. If the label cannot
         * be parsed the filter is left unchanged.
         *
         * @param outputValue the label to parse
         * @throws IllegalArgumentException if the label has none of the forms above or a
         *         bound is not a valid number ({@link NumberFormatException})
         */
        @Override
        public void addOutputValue(String outputValue) {
            outputValue = outputValue.trim();
            if (outputValue.equals(UNBOUNDED_LABEL)) {
                // Go through init so that outputValues stays consistent with splitPoints.
                init(new double[0]);
            } else if (outputValue.startsWith("<=")) {
                addSplitPointIfAbsent(Double.parseDouble(outputValue.substring(2).trim()));
            } else if (outputValue.startsWith(">")) {
                addSplitPointIfAbsent(Double.parseDouble(outputValue.substring(1).trim()));
            } else if (outputValue.contains("&&")) {
                String[] parts = outputValue.split("&&");
                if (parts.length != 2) {
                    throw new IllegalArgumentException("Unable to parse output value " + outputValue + "!");
                }
                String[] lowerTokens = parts[0].split("<");
                String[] upperTokens = parts[1].split("<=");
                if (lowerTokens.length < 1 || upperTokens.length < 2) {
                    throw new IllegalArgumentException("Unable to parse output value " + outputValue + "!");
                }
                // Parse both bounds before touching the state, so that a malformed upper
                // bound does not leave the lower bound behind.
                double lower = Double.parseDouble(lowerTokens[0].trim());
                double upper = Double.parseDouble(upperTokens[1].trim());
                addSplitPointIfAbsent(lower);
                addSplitPointIfAbsent(upper);
            } else {
                throw new IllegalArgumentException("Unable to parse output value " + outputValue + "!");
            }
        }

        /**
         * Returns the labels of all bins.
         *
         * @return the internal label array (not a copy), indexed by bin
         */
        @Override
        public String[] getOutputValues() {
            return outputValues;
        }

        /**
         * Returns the label of the bin containing the given value. Upper bounds are
         * matched with a small tolerance (1e-5 of the width of a neighbouring bin, or of
         * the split point's magnitude if there is only one split point), so values
         * marginally above a split point still count as belonging to the lower bin.
         *
         * @param continuous the continuous value to discretize
         * @return the label of the matching bin
         */
        @Override
        public String getValueForContinuous(double continuous) {
            int n = splitPoints.length;
            for (int i = 0; i < n; i++) {
                double width;
                if (i > 0) {
                    width = splitPoints[i] - splitPoints[i - 1];
                } else if (n > 1) {
                    width = splitPoints[1] - splitPoints[0];
                } else {
                    // Single split point: no neighbouring bin to take a width from. Use
                    // the magnitude, as a negative value would turn the tolerance into a
                    // penalty and exclude the split point itself from its bin.
                    width = Math.abs(splitPoints[0]);
                }
                if (continuous <= splitPoints[i] + width * 1e-5) {
                    return outputValues[i];
                }
            }
            return outputValues[n];
        }

        /**
         * Returns a textual representation listing the split points and the bin labels.
         */
        @Override
        public String toString() {
            return "Discretized(splitPoints = " + Arrays.toString(splitPoints)
                    + "\n\tOutputValues: " + Arrays.toString(outputValues) + ")";
        }

        /** Adds the split point unless it is already present. */
        private void addSplitPointIfAbsent(double point) {
            if (!containsSplitPoint(point)) {
                addSplitPoint(point);
            }
        }

        /** Rejects bin indices outside {@code [0, splitPoints.length]}. */
        private void checkBin(int bin) {
            if (bin < 0) {
                throw new IllegalArgumentException("Value out of range: " + bin + "<0");
            }
            if (splitPoints.length < bin) {
                throw new IllegalArgumentException("Value out of range: " + bin + ">" + splitPoints.length);
            }
        }
    }
}