00001
00002
00003
00004
00005
00006
00007 package edu.tum.cs.inference;
00008
00009 import java.io.PrintStream;
00010 import java.lang.reflect.InvocationTargetException;
00011 import java.util.Vector;
00012
00013 import umontreal.iro.lecuyer.probdist.BetaDist;
00014
00015 public abstract class BasicSampledDistribution implements IParameterHandler {
00020 public double[][] values = null;
00024 public Double Z = null;
00029 public Double confidenceLevel = null;
00030 public ParameterHandler paramHandler;
00031
/**
 * Creates the parameter handler for this object and registers the
 * {@code "confidenceLevel"} parameter together with the name
 * {@code "setConfidenceLevel"}.
 *
 * @throws Exception if creating the handler or registering the parameter
 *         fails (the exact conditions are defined by ParameterHandler and
 *         are not visible here)
 */
public BasicSampledDistribution() throws Exception {
    ParameterHandler handler = new ParameterHandler(this);
    this.paramHandler = handler;
    // NOTE(review): presumably maps the parameter name to the setter of
    // the same name on this object - confirm against ParameterHandler.
    handler.add("confidenceLevel", "setConfidenceLevel");
}
00036
00037 public double getProbability(int varIdx, int domainIdx) {
00038 return values[varIdx][domainIdx] / Z;
00039 }
00040
00046 public double[] getDistribution(int varIdx) {
00047 double[] ret = new double[values[varIdx].length];
00048 for(int i = 0; i < ret.length; i++)
00049 ret[i] = values[varIdx][i] / Z;
00050 return ret;
00051 }
00052
00053 public void print(PrintStream out) {
00054 for(int i = 0; i < values.length; i++) {
00055 printVariableDistribution(out, i);
00056 }
00057 }
00058
/**
 * Returns the number of samples behind this distribution.
 *
 * <p>{@code ConfidenceInterval} unboxes the result to {@code int}, so an
 * implementation must not return {@code null} when confidence intervals
 * are requested.
 *
 * @return the sample count
 */
public abstract Integer getNumSamples();
00060
00061 public void printVariableDistribution(PrintStream out, int idx) {
00062 out.println(getVariableName(idx) + ":");
00063 String[] domain = getDomain(idx);
00064 for(int j = 0; j < domain.length; j++) {
00065 double prob = values[idx][j] / Z;
00066 if(confidenceLevel == null)
00067 out.printf(" %.4f %s\n", prob, domain[j]);
00068 else {
00069 out.printf(" %.4f %s %s", prob, getConfidenceInterval(idx, j).toString());
00070 }
00071 }
00072 }
00073
00074 public ConfidenceInterval getConfidenceInterval(int varIdx, int domIdx) {
00075 return new ConfidenceInterval(varIdx, domIdx);
00076 }
00077
/**
 * Returns the name of the variable at the given index.
 *
 * @param idx index of the variable
 * @return the variable's name
 */
public abstract String getVariableName(int idx);

/**
 * Returns the index of the variable with the given name.
 *
 * <p>{@code DistributionComparison.compare} treats a negative return
 * value as "no such variable".
 *
 * @param name name of the variable
 * @return the variable's index
 */
public abstract int getVariableIndex(String name);

/**
 * Returns the domain of the variable at the given index, one string per
 * domain element. {@code printVariableDistribution} prints one line per
 * entry of the returned array.
 *
 * @param idx index of the variable
 * @return the domain elements
 */
public abstract String[] getDomain(int idx);
00081
00082 public int getDomainSize(int idx) {
00083 return values[idx].length;
00084 }
00085
00086 public GeneralSampledDistribution toGeneralDistribution() throws Exception {
00087 int numVars = values.length;
00088 String[] varNames = new String[numVars];
00089 String[][] domains = new String[numVars][];
00090 for(int i = 0; i < numVars; i++) {
00091 varNames[i] = getVariableName(i);
00092 domains[i] = getDomain(i);
00093 }
00094 return new GeneralSampledDistribution(this.values, this.Z, varNames, domains);
00095 }
00096
00103 public double getMSE(BasicSampledDistribution d) throws Exception {
00104 return compare(new MeanSquaredError(this), d);
00105 }
00106
00107 public double getHellingerDistance(BasicSampledDistribution d) throws Exception {
00108 return compare(new HellingerDistance(this), d);
00109 }
00110
00111 public double compare(DistributionEntryComparison dec, BasicSampledDistribution otherDist) throws Exception {
00112 DistributionComparison dc = new DistributionComparison(this, otherDist);
00113 dc.addEntryComparison(dec);
00114 dc.compare();
00115 return dec.getResult();
00116 }
00117
00118 public void setConfidenceLevel(Double confidenceLevel) {
00119 this.confidenceLevel = confidenceLevel;
00120 }
00121
00122 public boolean usesConfidenceComputation() {
00123 return confidenceLevel != null;
00124 }
00125
00126 public ParameterHandler getParameterHandler() {
00127 return paramHandler;
00128 }
00129
00130 public class ConfidenceInterval {
00131 public double lowerEnd, upperEnd;
00132 protected int precisionDigits = 4;
00133
00134 public ConfidenceInterval(int varIdx, int domIdx) {
00135 int numSamples = getNumSamples();
00136 double p = values[varIdx][domIdx] / Z;
00137 double alpha = p * numSamples;
00138 double beta = numSamples - alpha;
00139 alpha += 1;
00140 beta += 1;
00141 double confAlpha = 1-confidenceLevel;
00142 lowerEnd = BetaDist.inverseF(alpha, beta, precisionDigits, confAlpha/2);
00143 upperEnd = BetaDist.inverseF(alpha, beta, precisionDigits, 1-confAlpha/2);
00144 if(p > upperEnd) {
00145 lowerEnd = BetaDist.inverseF(alpha, beta, precisionDigits, confAlpha);
00146 upperEnd = 1.0;
00147 }
00148 else if(p < lowerEnd) {
00149 lowerEnd = 0.0;
00150 upperEnd = BetaDist.inverseF(alpha, beta, precisionDigits, 1-confAlpha);
00151 }
00152 }
00153
00154 public double getSize() {
00155 return upperEnd-lowerEnd;
00156 }
00157
00158 public String toString() {
00159 return String.format(String.format("[%%.%df;%%.%df] %%.4f", precisionDigits, precisionDigits), lowerEnd, upperEnd, getSize());
00160 }
00161 }
00162
00163 public static class DistributionComparison {
00164 protected BasicSampledDistribution referenceDist, otherDist;
00165 protected Vector<DistributionEntryComparison> processors;
00166
00167 public DistributionComparison(BasicSampledDistribution referenceDist, BasicSampledDistribution otherDist) {
00168 this.referenceDist = referenceDist;
00169 this.otherDist = otherDist;
00170 processors = new Vector<DistributionEntryComparison>();
00171 }
00172
00173 public void addEntryComparison(DistributionEntryComparison c) {
00174 processors.add(c);
00175 }
00176
00177 public void addEntryComparison(Class<? extends DistributionEntryComparison> c) throws IllegalArgumentException, SecurityException, InstantiationException, IllegalAccessException, InvocationTargetException, NoSuchMethodException {
00178 addEntryComparison(c.getConstructor(BasicSampledDistribution.class).newInstance(referenceDist));
00179 }
00180
00181 public void compare() throws Exception {
00182 for(int i = 0; i < referenceDist.values.length; i++) {
00183 String varName = referenceDist.getVariableName(i);
00184 int i2 = otherDist.getVariableIndex(varName);
00185 if(i2 < 0)
00186 throw new Exception("Variable " + referenceDist.getVariableName(i) + " has no correspondence in second distribution");
00187 for(int j = 0; j < referenceDist.values[i].length; j++) {
00188 double v1 = referenceDist.getProbability(i, j);
00189 double v2 = otherDist.getProbability(i2, j);
00190 for(DistributionEntryComparison p : processors)
00191 p.process(i, j, v1, v2);
00192 }
00193 }
00194 }
00195
00196 public void printResults() {
00197 for(DistributionEntryComparison dec : processors)
00198 dec.printResult();
00199 }
00200
00201 public double getResult(Class<? extends DistributionEntryComparison> c) throws Exception {
00202 for(DistributionEntryComparison p : processors)
00203 if(c.isInstance(p)) {
00204 return p.getResult();
00205 }
00206 throw new Exception(c.getSimpleName() + " was not processed in this comparison");
00207 }
00208 }
00209
00210 public static abstract class DistributionEntryComparison {
00211 BasicSampledDistribution refDist;
00212 public DistributionEntryComparison(BasicSampledDistribution refDist) {
00213 this.refDist = refDist;
00214 }
00215 public abstract void process(int varIdx, int domIdx, double p1, double p2);
00216 public abstract double getResult();
00217 public void printResult() {
00218 System.out.printf("%s = %s\n", getClass().getSimpleName(), getResult());
00219 }
00220 }
00221
00222 public static class MeanSquaredError extends DistributionEntryComparison {
00223 double sum = 0.0; int cnt = 0;
00224 public MeanSquaredError(BasicSampledDistribution refDist) {
00225 super(refDist);
00226 }
00227 @Override
00228 public void process(int varIdx, int domIdx, double p1, double p2) {
00229 ++cnt;
00230 double error = p1-p2;
00231 error *= error;
00232 sum += error;
00233 }
00234 @Override
00235 public double getResult() {
00236 return sum/cnt;
00237 }
00238 }
00239
00240 public static class HellingerDistance extends DistributionEntryComparison {
00241 double BhattacharyyaCoefficient = 0.0;
00242 double sum = 0.0;
00243 int prevVarIdx = -1;
00244 int numVars = 0;
00245 public HellingerDistance(BasicSampledDistribution refDist) {
00246 super(refDist);
00247 }
00248 @Override
00249 public void process(int varIdx, int domIdx, double p1, double p2) {
00250 if(prevVarIdx != varIdx) {
00251
00252
00253 prevVarIdx = varIdx;
00254 numVars++;
00255 sum += BhattacharyyaCoefficient;
00256 BhattacharyyaCoefficient = 0;
00257 }
00258 BhattacharyyaCoefficient += Math.sqrt(p1*p2);
00259 }
00260 @Override
00261 public double getResult() {
00262 sum += BhattacharyyaCoefficient;
00263 sum /= numVars;
00264 return Math.sqrt(1-sum);
00265 }
00266 }
00267
00268 public static class ErrorList extends DistributionEntryComparison {
00269 public ErrorList(BasicSampledDistribution refDist) {
00270 super(refDist);
00271 }
00272 @Override
00273 public void process(int varIdx, int domIdx, double p1, double p2) {
00274 double error = p1 - p2;
00275 if(error != 0.0) {
00276 System.out.printf("%s=%s: %f %f -> %f\n", refDist.getVariableName(varIdx), refDist.getDomain(varIdx)[domIdx], p1, p2, error);
00277 }
00278 }
00279 @Override
00280 public double getResult() {
00281 return 0;
00282 }
00283 @Override
00284 public void printResult() {}
00285 }
00286 }