TooN 2.1
#include <TooN/optimization/brent.h>
#include <utility>
#include <cmath>
#include <cassert>
#include <cstdlib>

namespace TooN{
namespace Internal{


///Turn a multidimensional function into a 1D function by specifying a
///point and direction. A new function is defined:
///\f[
/// g(a) = f(\Vec{s} + a \Vec{d})
///\f]
///@ingroup gOptimize
template<int Size, typename Precision, typename Func> struct LineSearch
{
	const Vector<Size, Precision>& start;    ///< \f$\Vec{s}\f$
	const Vector<Size, Precision>& direction;///< \f$\Vec{d}\f$

	const Func& f;///< \f$f(\cdotp)\f$

	///Set up the line search class.
	///@param s Start point, \f$\Vec{s}\f$.
	///@param d Direction, \f$\Vec{d}\f$.
	///@param func Function, \f$f(\cdotp)\f$.
	LineSearch(const Vector<Size, Precision>& s, const Vector<Size, Precision>& d, const Func& func)
	:start(s),direction(d),f(func)
	{}

	///@param x Position at which to evaluate the function
	///@return \f$f(\Vec{s} + x\Vec{d})\f$
	Precision operator()(Precision x) const
	{
		return f(start + x * direction);
	}
};
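//Illustrative sketch (not part of the original header): given some user
//function double f(const Vector<2>&), LineSearch evaluates f along a ray
//from a start point in a fixed direction:
//
//  Vector<2> s = makeVector(1.0, 2.0);                        //start point s
//  Vector<2> d = makeVector(0.0, 1.0);                        //direction d
//  LineSearch<2, double, double (const Vector<2>&)> line(s, d, f);
//  double v = line(0.5);                                      //f(s + 0.5*d)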
///Bracket a 1D function by searching forward from zero. The assumption
///is that a minimum exists in \f$f(x),\ x>0\f$, and this function searches
///for a bracket using exponentially growing or shrinking steps.
///@param a_val The value of the function at zero.
///@param func Function to bracket
///@param initial_lambda Initial stepsize
///@param zeps Minimum bracket size.
///@return <code>m[i][0]</code> contains the values of \f$x\f$ for the bracket, in increasing order,
///        and <code>m[i][1]</code> contains the corresponding values of \f$f(x)\f$. If the bracket
///        drops below the minimum bracket size, all zeros are returned.
///@ingroup gOptimize
template<typename Precision, typename Func> Matrix<3,2,Precision> bracket_minimum_forward(Precision a_val, const Func& func, Precision initial_lambda, Precision zeps)
{
	//Get a, b, c to bracket a minimum along a line
	Precision a, b, c, b_val, c_val;

	a=0;

	//Search forward in steps of lambda
	Precision lambda=initial_lambda;
	b = lambda;
	b_val = func(b);

	while(std::isnan(b_val))
	{
		//We've probably gone into an invalid region. This can happen even
		//if following the gradient would never get us there.
		//Try backing off lambda.
		lambda*=.5;
		b = lambda;
		b_val = func(b);
	}

	if(b_val < a_val) //We've gone downhill, so keep searching until we go back up
	{
		Precision last_good_lambda = lambda;

		for(;;)
		{
			lambda *= 2;
			c = lambda;
			c_val = func(c);

			if(std::isnan(c_val))
				break;
			last_good_lambda = lambda;

			if(c_val > b_val) //we have a bracket
				break;
			else
			{
				a = b;
				a_val = b_val;
				b = c;
				b_val = c_val;
			}
		}

		//We took a step too far.
		//Back up: this will not attempt to ensure a bracket.
		if(std::isnan(c_val))
		{
			Precision bad_lambda=lambda;
			Precision l=1;

			for(;;)
			{
				l*=.5;
				c = last_good_lambda + (bad_lambda - last_good_lambda)*l;
				c_val = func(c);

				if(!std::isnan(c_val))
					break;
			}
		}
	}
	else //We've overshot the minimum, so back up
	{
		c = b;
		c_val = b_val;
		//Here, c_val > a_val

		for(;;)
		{
			lambda *= .5;
			b = lambda;
			b_val = func(b);

			if(b_val < a_val) //we have a bracket
				break;
			else if(lambda < zeps)
				return Zeros;
			else //Contract the bracket
			{
				c = b;
				c_val = b_val;
			}
		}
	}

	Matrix<3,2,Precision> ret;
	ret[0] = makeVector(a, a_val);
	ret[1] = makeVector(b, b_val);
	ret[2] = makeVector(c, c_val);

	return ret;
}

} //namespace Internal
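//Illustrative sketch (not part of the original header): bracketing the
//minimum of g(x) = (x-2)^2, which lies at x=2, starting from g(0)=4 with an
//initial step of 1:
//
//  double g(double x) { return (x-2)*(x-2); }
//  ...
//  Matrix<3,2> m = Internal::bracket_minimum_forward(g(0.0), g, 1.0, 1e-20);
//
//The steps double while the function decreases, giving the bracket
//m[0][0]=1 < m[1][0]=2 < m[2][0]=4, with m[i][1] holding the corresponding
//values g = 1, 0, 4.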
/** This class provides a nonlinear conjugate-gradient optimizer. The following
code snippet will perform an optimization of the Rosenbrock banana function in
two dimensions:

@code
double Rosenbrock(const Vector<2>& v)
{
	return sq(1 - v[0]) + 100 * sq(v[1] - sq(v[0]));
}

Vector<2> RosenbrockDerivatives(const Vector<2>& v)
{
	double x = v[0];
	double y = v[1];

	Vector<2> ret;
	ret[0] = -2+2*x-400*(y-sq(x))*x;
	ret[1] = 200*y-200*sq(x);

	return ret;
}

int main()
{
	ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);

	while(cg.iterate(Rosenbrock, RosenbrockDerivatives))
		cout << "y_" << cg.iterations << " = " << cg.y << endl;

	cout << "Optimal value: " << cg.y << endl;
}
@endcode

The chances are that you will want to read the documentation for
ConjugateGradient::ConjugateGradient and ConjugateGradient::iterate.

The linesearch is currently performed using Brent's method (see
brent_line_search), and conjugate vector updates are performed using the
Polak-Ribiere equations. There are many tunable parameters, and the internals
are readily accessible, so alternative termination conditions etc. can easily
be substituted. However, usually these will not be necessary.

@ingroup gOptimize
*/
template<int Size=Dynamic, class Precision=double> struct ConjugateGradient
{
	const int size;      ///< Dimensionality of the space.
	Vector<Size> g;      ///< Gradient vector used by the next call to iterate()
	Vector<Size> h;      ///< Conjugate vector to be searched along in the next call to iterate()
	Vector<Size> minus_h;///< Negative of h, stored because it must be passed into a function by reference (so it cannot be a temporary)
	Vector<Size> old_g;  ///< Gradient vector used to compute \f$h\f$ in the last call to iterate()
	Vector<Size> old_h;  ///< Conjugate vector searched along in the last call to iterate()
	Vector<Size> x;      ///< Current position (best known point)
	Vector<Size> old_x;  ///< Previous best known point (not set at construction)
	Precision y;         ///< Function value at \f$x\f$
	Precision old_y;     ///< Function value at old_x

	Precision tolerance; ///< Tolerance used to determine if the optimization is complete. Defaults to the square root of machine precision.
	Precision epsilon;   ///< Additive term in the tolerance to prevent excessive iterations if \f$x_\mathrm{optimal} = 0\f$. Known as \c ZEPS in Numerical Recipes. Defaults to 1e-20.
	int max_iterations;  ///< Maximum number of iterations. Defaults to \c size\f$*100\f$.

	Precision bracket_initial_lambda;///< Initial stepsize used in bracketing the minimum for the line search. Defaults to 1.
	Precision linesearch_tolerance;  ///< Tolerance used to determine if the linesearch is complete. Defaults to the square root of machine precision.
	Precision linesearch_epsilon;    ///< Additive term in the tolerance to prevent excessive iterations if \f$x_\mathrm{optimal} = 0\f$. Known as \c ZEPS in Numerical Recipes. Defaults to 1e-20.
	int linesearch_max_iterations;   ///< Maximum number of iterations in the linesearch. Defaults to 100.

	Precision bracket_epsilon; ///< Minimum size for the initial minimum bracketing. Below this, the system is assumed to have converged. Defaults to 1e-20.

	int iterations; ///< Number of iterations performed

	///Initialize the ConjugateGradient class with sensible values.
	///@param start Starting point, \e x
	///@param func Function \e f to compute \f$f(x)\f$
	///@param deriv Function to compute \f$\nabla f(x)\f$
	template<class Func, class Deriv> ConjugateGradient(const Vector<Size>& start, const Func& func, const Deriv& deriv)
	: size(start.size()),
	  g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
	{
		init(start, func(start), deriv(start));
	}

	///Initialize the ConjugateGradient class with sensible values.
	///@param start Starting point, \e x
	///@param func Function \e f to compute \f$f(x)\f$
	///@param deriv \f$\nabla f(x)\f$
	template<class Func> ConjugateGradient(const Vector<Size>& start, const Func& func, const Vector<Size>& deriv)
	: size(start.size()),
	  g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
	{
		init(start, func(start), deriv);
	}

	///Initialize the ConjugateGradient class with sensible values. Used internally.
	///@param start Starting point, \e x
	///@param func \f$f(x)\f$
	///@param deriv \f$\nabla f(x)\f$
	void init(const Vector<Size>& start, const Precision& func, const Vector<Size>& deriv)
	{
		using std::numeric_limits;
		using std::sqrt;
		x = start;

		//Start with the conjugate direction aligned with
		//the gradient
		g = deriv;
		h = g;
		minus_h=-h;

		y = func;
		old_y = y;

		tolerance = sqrt(numeric_limits<Precision>::epsilon());
		epsilon = 1e-20;
		max_iterations = size * 100;

		bracket_initial_lambda = 1;

		linesearch_tolerance = sqrt(numeric_limits<Precision>::epsilon());
		linesearch_epsilon = 1e-20;
		linesearch_max_iterations=100;

		bracket_epsilon=1e-20;

		iterations=0;
	}
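	//Illustrative sketch (not part of the original header): since the members
	//above are public, the tuning parameters set by init() can be overridden
	//after construction, e.g.:
	//
	//  ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);
	//  cg.max_iterations = 500;          //allow more outer iterations
	//  cg.linesearch_tolerance = 1e-10;  //demand a more accurate linesearch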
	///Perform a linesearch from the current point (x) along the current
	///conjugate vector (h). The linesearch does not make use of derivatives.
	///You probably do not want to use this function. See iterate() instead.
	///This function updates:
	/// - x
	/// - old_x
	/// - y
	/// - old_y
	/// - iterations
	/// Note that the conjugate direction and gradient are not updated.
	/// If bracket_minimum_forward detects a local maximum, then essentially a zero
	/// sized step is taken.
	/// @param func Functor returning the function value at a given point.
	template<class Func> void find_next_point(const Func& func)
	{
		Internal::LineSearch<Size, Precision, Func> line(x, minus_h, func);

		//Always search in the conjugate direction (h).
		//First, bracket a minimum.
		Matrix<3,2,Precision> bracket = Internal::bracket_minimum_forward(y, line, bracket_initial_lambda, bracket_epsilon);

		Precision a = bracket[0][0];
		Precision b = bracket[1][0];
		Precision c = bracket[2][0];

		Precision a_val = bracket[0][1];
		Precision b_val = bracket[1][1];
		Precision c_val = bracket[2][1];

		old_y = y;
		old_x = x;
		iterations++;

		//Local maximum achieved!
		if(a == 0 && b == 0 && c == 0)
			return;

		//We should have a bracket here

		if(c < b)
		{
			//Failed to bracket due to NaN, so c is the best known point.
			//Simply go there.
			x -= h * c;
			y = c_val;
		}
		else
		{
			assert(a < b && b < c);
			assert(a_val > b_val && b_val < c_val);

			//Find the real minimum
			Vector<2, Precision> m = brent_line_search(a, b, c, b_val, line, linesearch_max_iterations, linesearch_tolerance, linesearch_epsilon);

			assert(m[0] >= a && m[0] <= c);
			assert(m[1] <= b_val);

			//Update the current position and value
			x -= m[0] * h;
			y = m[1];
		}
	}

	///Check to see if iteration should stop. You probably do not want to use
	///this function. See iterate() instead. This function updates nothing.
	bool finished()
	{
		using std::abs;
		return iterations > max_iterations || 2*abs(y - old_y) <= tolerance * (abs(y) + abs(old_y) + epsilon);
	}

	///After an iteration, update the gradient and conjugate directions using
	///the Polak-Ribiere equations:
	///\f[
	/// \gamma = \frac{g \cdot g - g_{old} \cdot g}{g_{old} \cdot g_{old}}, \qquad
	/// h = g + \gamma h_{old}
	///\f]
	///This function updates:
	///- g
	///- old_g
	///- h
	///- old_h
	///@param grad The derivatives of the function at \e x
	void update_vectors_PR(const Vector<Size>& grad)
	{
		//Update the gradient and conjugate directions
		old_g = g;
		old_h = h;

		g = grad;
		Precision gamma = (g * g - old_g*g)/(old_g * old_g);
		h = g + gamma * old_h;
		minus_h=-h;
	}

	///Use this function to iterate over the optimization. Note that after
	///iterate returns false, g, h, old_g and old_h will not have been
	///updated.
	///This function updates:
	/// - x
	/// - old_x
	/// - y
	/// - old_y
	/// - iterations
	/// - g*
	/// - old_g*
	/// - h*
	/// - old_h*
	/// *'d variables are not updated on the last iteration.
	///@param func Functor returning the function value at a given point.
	///@param deriv Functor to compute derivatives at the specified point.
	///@return Whether to continue.
	template<class Func, class Deriv> bool iterate(const Func& func, const Deriv& deriv)
	{
		find_next_point(func);

		if(!finished())
		{
			update_vectors_PR(deriv(x));
			return true;
		}
		else
			return false;
	}
};

} //namespace TooN
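//Illustrative sketch (not part of the original header): because the members
//of ConjugateGradient are public, iterate() can be replaced with a custom
//loop, e.g. one that terminates on the gradient magnitude instead:
//
//  ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);
//  for(;;)
//  {
//      cg.find_next_point(Rosenbrock);
//      Vector<2> grad = RosenbrockDerivatives(cg.x);
//
//      if(norm(grad) < 1e-8 || cg.iterations > cg.max_iterations)
//          break;
//
//      cg.update_vectors_PR(grad);
//  }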