00001 #ifndef TOON_DOWNHILL_SIMPLEX_H
00002 #define TOON_DOWNHILL_SIMPLEX_H
00003 #include <TooN/TooN.h>
00004 #include <TooN/helpers.h>
00005 #include <algorithm>
00006 #include <cstdlib>
00007
00008 namespace TooN
00009 {
00010
00077 template<int N=-1, typename Precision=double> class DownhillSimplex
00078 {
00079 static const int Vertices = (N==-1?-1:N+1);
00080 typedef Matrix<Vertices, N, Precision> Simplex;
00081 typedef Vector<Vertices, Precision> Values;
00082
00083 public:
00092 template<class Function> DownhillSimplex(const Function& func, const Vector<N>& c, Precision spread=1)
00093 :simplex(c.size()+1, c.size()),values(c.size()+1)
00094 {
00095 alpha = 1.0;
00096 rho = 2.0;
00097 gamma = 0.5;
00098 sigma = 0.5;
00099
00100 epsilon = sqrt(numeric_limits<Precision>::epsilon());
00101 zero_epsilon = 1e-20;
00102
00103 restart(func, c, spread);
00104 }
00105
00112 template<class Function> void restart(const Function& func, const Vector<N>& c, Precision spread)
00113 {
00114 for(int i=0; i < simplex.num_rows(); i++)
00115 simplex[i] = c;
00116
00117 for(int i=0; i < simplex.num_cols(); i++)
00118 simplex[i][i] += spread;
00119
00120 for(int i=0; i < values.size(); i++)
00121 values[i] = func(simplex[i]);
00122 }
00123
00129 bool finished()
00130 {
00131 Precision span = norm(simplex[get_best()] - simplex[get_worst()]);
00132 Precision scale = norm(simplex[get_best()]);
00133
00134 if(span/scale < epsilon || span < zero_epsilon)
00135 return 1;
00136 else
00137 return 0;
00138 }
00139
00144 template<class Function> void restart(const Function& func, Precision spread)
00145 {
00146 restart(func, simplex[get_best()], spread);
00147 }
00148
00150 const Simplex& get_simplex() const
00151 {
00152 return simplex;
00153 }
00154
00156 const Values& get_values() const
00157 {
00158 return values;
00159 }
00160
00162 int get_best() const
00163 {
00164 return std::min_element(&values[0], &values[0] + values.size()) - &values[0];
00165 }
00166
00168 int get_worst() const
00169 {
00170 return std::max_element(&values[0], &values[0] + values.size()) - &values[0];
00171 }
00172
00175 template<class Function> void find_next_point(const Function& func)
00176 {
00177
00178
00179
00180
00181
00182 int worst = get_worst();
00183 Precision second_worst_val=-HUGE_VAL, bestval = HUGE_VAL, worst_val = values[worst];
00184 int best=0;
00185 Vector<N> x0 = Zeros(simplex.num_cols());
00186
00187
00188 for(int i=0; i < simplex.num_rows(); i++)
00189 {
00190 if(values[i] < bestval)
00191 {
00192 bestval = values[i];
00193 best = i;
00194 }
00195
00196 if(i != worst)
00197 {
00198 if(values[i] > second_worst_val)
00199 second_worst_val = values[i];
00200
00201
00202 x0 += simplex[i];
00203 }
00204 }
00205 x0 *= 1.0 / simplex.num_cols();
00206
00207
00208
00209 Vector<N> xr = (1 + alpha) * x0 - alpha * simplex[worst];
00210 Precision fr = func(xr);
00211
00212 if(fr < bestval)
00213 {
00214
00215 Vector<N> xe = rho * xr + (1-rho) * x0;
00216 Precision fe = func(xe);
00217
00218
00219 if(fe < fr)
00220 {
00221 simplex[worst] = xe;
00222 values[worst] = fe;
00223 }
00224 else
00225 {
00226 simplex[worst] = xr;
00227 values[worst] = fr;
00228 }
00229
00230 return;
00231 }
00232
00233
00234
00235 if(fr < second_worst_val)
00236 {
00237 simplex[worst] = xr;
00238 values[worst] = fr;
00239 return;
00240 }
00241
00242
00243
00244
00245
00246 if(fr < worst_val)
00247 {
00248 Vector<N> xc = (1 + gamma) * x0 - gamma * simplex[worst];
00249 Precision fc = func(xc);
00250
00251
00252 if(fc <= fr)
00253 {
00254 simplex[worst] = xc;
00255 values[worst] = fc;
00256 return;
00257 }
00258 }
00259
00260
00261
00262 for(int i=0; i < simplex.num_rows(); i++)
00263 if(i != best)
00264 {
00265 simplex[i] = simplex[best] + sigma * (simplex[i] - simplex[best]);
00266 values[i] = func(simplex[i]);
00267 }
00268 }
00269
00273 template<class Function> bool iterate(const Function& func)
00274 {
00275 find_next_point(func);
00276 return !finished();
00277 }
00278
00279 Precision alpha;
00280 Precision rho;
00281 Precision gamma;
00282 Precision sigma;
00283 Precision epsilon;
00284 Precision zero_epsilon;
00285
00286 private:
00287
00288
00289 Simplex simplex;
00290
00291
00292 Values values;
00293
00294
00295 };
00296 }
00297 #endif