00001 #include <TooN/optimization/brent.h>
00002 #include <utility>
00003 #include <cmath>
00004 #include <cassert>
00005 #include <cstdlib>
00006
00007 namespace TooN{
00008 namespace Internal{
00009
00010
00017 template<int Size, typename Precision, typename Func> struct LineSearch
00018 {
00019 const Vector<Size, Precision>& start;
00020 const Vector<Size, Precision>& direction;
00021
00022 const Func& f;
00023
00028 LineSearch(const Vector<Size, Precision>& s, const Vector<Size, Precision>& d, const Func& func)
00029 :start(s),direction(d),f(func)
00030 {}
00031
00034 Precision operator()(Precision x) const
00035 {
00036 return f(start + x * direction);
00037 }
00038 };
00039
00051 template<typename Precision, typename Func> Matrix<3,2,Precision> bracket_minimum_forward(Precision a_val, const Func& func, Precision initial_lambda, Precision zeps)
00052 {
00053
00054 Precision a, b, c, b_val, c_val;
00055
00056 a=0;
00057
00058
00059 Precision lambda=initial_lambda;
00060 b = lambda;
00061 b_val = func(b);
00062
00063 while(std::isnan(b_val))
00064 {
00065
00066
00067
00068 lambda*=.5;
00069 b = lambda;
00070 b_val = func(b);
00071
00072 }
00073
00074
00075 if(b_val < a_val)
00076 {
00077 double last_good_lambda = lambda;
00078
00079 for(;;)
00080 {
00081 lambda *= 2;
00082 c = lambda;
00083 c_val = func(c);
00084
00085 if(std::isnan(c_val))
00086 break;
00087 last_good_lambda = lambda;
00088 if(c_val > b_val)
00089 break;
00090 else
00091 {
00092 a = b;
00093 a_val = b_val;
00094 b=c;
00095 b_val=c_val;
00096
00097 }
00098 }
00099
00100
00101
00102 if(std::isnan(c_val))
00103 {
00104 double bad_lambda=lambda;
00105 double l=1;
00106
00107 for(;;)
00108 {
00109 l*=.5;
00110 c = last_good_lambda + (bad_lambda - last_good_lambda)*l;
00111 c_val = func(c);
00112
00113 if(!std::isnan(c_val))
00114 break;
00115 }
00116
00117
00118 }
00119
00120 }
00121 else
00122 {
00123 c = b;
00124 c_val = b_val;
00125
00126
00127 for(;;)
00128 {
00129 lambda *= .5;
00130 b = lambda;
00131 b_val = func(b);
00132
00133 if(b_val < a_val)
00134 break;
00135 else if(lambda < zeps)
00136 return Zeros;
00137 else
00138 {
00139 c = b;
00140 c_val = b_val;
00141 }
00142 }
00143 }
00144
00145 Matrix<3,2> ret;
00146 ret[0] = makeVector(a, a_val);
00147 ret[1] = makeVector(b, b_val);
00148 ret[2] = makeVector(c, c_val);
00149
00150 return ret;
00151 }
00152
00153 }
00154
00155
00200 template<int Size, class Precision=double> struct ConjugateGradient
00201 {
00202 const int size;
00203 Vector<Size> g;
00204 Vector<Size> h;
00205 Vector<Size> minus_h;
00206 Vector<Size> old_g;
00207 Vector<Size> old_h;
00208 Vector<Size> x;
00209 Vector<Size> old_x;
00210 Precision y;
00211 Precision old_y;
00212
00213 Precision tolerance;
00214 Precision epsilon;
00215 int max_iterations;
00216
00217 Precision bracket_initial_lambda;
00218 Precision linesearch_tolerance;
00219 Precision linesearch_epsilon;
00220 int linesearch_max_iterations;
00221
00222 Precision bracket_epsilon;
00223
00224 int iterations;
00225
00230 template<class Func, class Deriv> ConjugateGradient(const Vector<Size>& start, const Func& func, const Deriv& deriv)
00231 : size(start.size()),
00232 g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
00233 {
00234 init(start, func(start), deriv(start));
00235 }
00236
00241 template<class Func> ConjugateGradient(const Vector<Size>& start, const Func& func, const Vector<Size>& deriv)
00242 : size(start.size()),
00243 g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
00244 {
00245 init(start, func(start), deriv);
00246 }
00247
00252 void init(const Vector<Size>& start, const Precision& func, const Vector<Size>& deriv)
00253 {
00254
00255 using std::numeric_limits;
00256 x = start;
00257
00258
00259
00260 g = deriv;
00261 h = g;
00262 minus_h=-h;
00263
00264 y = func;
00265 old_y = y;
00266
00267 tolerance = sqrt(numeric_limits<Precision>::epsilon());
00268 epsilon = 1e-20;
00269 max_iterations = size * 100;
00270
00271 bracket_initial_lambda = 1;
00272
00273 linesearch_tolerance = sqrt(numeric_limits<Precision>::epsilon());
00274 linesearch_epsilon = 1e-20;
00275 linesearch_max_iterations=100;
00276
00277 bracket_epsilon=1e-20;
00278
00279 iterations=0;
00280 }
00281
00282
00296 template<class Func> void find_next_point(const Func& func)
00297 {
00298 Internal::LineSearch<Size, Precision, Func> line(x, minus_h, func);
00299
00300
00301
00302 Matrix<3,2,Precision> bracket = Internal::bracket_minimum_forward(y, line, bracket_initial_lambda, bracket_epsilon);
00303
00304 double a = bracket[0][0];
00305 double b = bracket[1][0];
00306 double c = bracket[2][0];
00307
00308 double a_val = bracket[0][1];
00309 double b_val = bracket[1][1];
00310 double c_val = bracket[2][1];
00311
00312 old_y = y;
00313 old_x = x;
00314 iterations++;
00315
00316
00317 if(a==0 && b== 0 && c == 0)
00318 return;
00319
00320
00321
00322 if(c < b)
00323 {
00324
00325
00326 x-=h * c;
00327 y=c_val;
00328
00329 }
00330 else
00331 {
00332 assert(a < b && b < c);
00333 assert(a_val > b_val && b_val < c_val);
00334
00335
00336 Vector<2, Precision> m = brent_line_search(a, b, c, b_val, line, linesearch_max_iterations, linesearch_tolerance, linesearch_epsilon);
00337
00338 assert(m[0] >= a && m[0] <= c);
00339 assert(m[1] <= b_val);
00340
00341
00342 x -= m[0] * h;
00343 y = m[1];
00344 }
00345 }
00346
00349 bool finished()
00350 {
00351 using std::abs;
00352 return iterations > max_iterations || 2*abs(y - old_y) <= tolerance * (abs(y) + abs(old_y) + epsilon);
00353 }
00354
00363 void update_vectors_PR(const Vector<Size>& grad)
00364 {
00365
00366 old_g = g;
00367 old_h = h;
00368
00369 g = grad;
00370
00371 Precision gamma = (g * g - old_g*g)/(old_g * old_g);
00372 h = g + gamma * old_h;
00373 minus_h=-h;
00374 }
00375
00393 template<class Func, class Deriv> bool iterate(const Func& func, const Deriv& deriv)
00394 {
00395 find_next_point(func);
00396
00397 if(!finished())
00398 {
00399 update_vectors_PR(deriv(x));
00400 return 1;
00401 }
00402 else
00403 return 0;
00404 }
00405 };
00406
00407 }