00001 #ifndef SCITBX_LINE_SEARCH_MORE_THUENTE_1994_RAW_H
00002 #define SCITBX_LINE_SEARCH_MORE_THUENTE_1994_RAW_H
00003
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>
00008
00009 namespace scitbx { namespace line_search {
00010
/// Moré-Thuente line search driven by reverse communication.
///
/// The object keeps the search state between calls to run(): whenever run()
/// returns with info == -1 it has written a trial point into x, and the
/// caller must evaluate the function value and gradient there and call run()
/// again.
template <typename FloatType=double>
class mcsrch
{
  protected:
    // Return value of the most recent mcstep() call; 0 means mcstep()
    // rejected its input.
    int infoc;
    // Directional derivative g.s at the start point (must be negative).
    FloatType dginit;
    // Set by mcstep() once the interval of uncertainty has been bracketed.
    bool brackt;
    // First stage of the search; cleared by run() once a step with
    // f <= ftest1 and dg >= min(ftol, gtol) * dginit has been seen.
    bool stage1;
    // Function value at the start point.
    FloatType finit;
    // ftol * dginit: slope used by the sufficient-decrease test.
    FloatType dgtest;
    // Interval widths used by run() to force the interval of uncertainty
    // to shrink (initialised from stpmax - stpmin).
    FloatType width;
    FloatType width1;
    // Endpoint with the lowest function value so far: step, function value
    // and directional derivative.
    FloatType stx;
    FloatType fx;
    FloatType dgx;
    // Other endpoint of the interval of uncertainty.
    FloatType sty;
    FloatType fy;
    FloatType dgy;
    // Bounds of the interval that the current trial step must lie in.
    FloatType stmin;
    FloatType stmax;
    // Copy of x at the start of the search; trial points are
    // initial_x + stp * s.
    std::vector<FloatType> initial_x;

    /// Square of x.
    static
    FloatType
    pow2(FloatType const& x) { return x * x; }

    /// Largest of three values.  Returned by value: the arguments are
    /// typically temporaries (e.g. std::abs(theta)), so returning a
    /// reference to one of them could dangle.
    static
    FloatType
    max3(
      FloatType const& x,
      FloatType const& y,
      FloatType const& z)
    {
      return x < y ? (y < z ? z : y ) : (x < z ? z : x );
    }

  public:
    /// Description of the info code set by the most recent run() call;
    /// null before the first call.
    const char* info_meaning;

    /// Zero-initialises the search state.  run() (re)initialises the state
    /// it needs whenever it is entered with info != -1.
    mcsrch()
    :
      infoc(0),
      dginit(0),
      brackt(false),
      stage1(true),
      finit(0),
      dgtest(0),
      width(0),
      width1(0),
      stx(0),
      fx(0),
      dgx(0),
      sty(0),
      fy(0),
      dgy(0),
      stmin(0),
      stmax(0),
      info_meaning(0)
    {}

    /// Releases the memory held by the copy of the start point.
    void
    free_workspace() { initial_x = std::vector<FloatType>(); }

    /// Reverse-communication Moré-Thuente line search.
    ///
    /// Call with info != -1 to start a new search from x along s.  run()
    /// validates the input, copies x, writes the first trial point
    /// initial_x + stp*s into x and returns with info == -1.  Evaluate f and
    /// g at x and call run() again with info still -1; repeat until info is
    /// no longer -1.
    ///
    /// \param gtol   curvature tolerance (success requires
    ///               |g.s| <= gtol * |g0.s|, g0 the start gradient).
    /// \param stpmin lower bound for the step (>= 0).
    /// \param stpmax upper bound for the step (>= stpmin).
    /// \param n      number of variables (> 0).
    /// \param x      in: start point of a new search; out: current trial
    ///               point.
    /// \param f      function value at x.
    /// \param g      gradient at x (n values).
    /// \param s      search direction (n values); g.s must be negative at
    ///               the start point.
    /// \param stp    in: initial step (> 0) for a new search; out: current
    ///               trial step (on termination, the last evaluated step).
    /// \param ftol   sufficient-decrease tolerance (>= 0).
    /// \param xtol   relative tolerance for the width of the interval of
    ///               uncertainty (>= 0).
    /// \param maxfev maximum number of function evaluations (> 0).
    /// \param info   -1 requests/continues an evaluation.  Termination
    ///               codes: 1 sufficient decrease and directional
    ///               derivative conditions hold; 2 interval width at most
    ///               xtol; 3 maxfev reached; 4 step at stpmin; 5 step at
    ///               stpmax; 6 rounding errors prevent progress.
    ///               info_meaning describes the code.
    /// \param nfev   number of function evaluations in the current search.
    ///
    /// \throws std::runtime_error for improper input or when s is not a
    ///         descent direction at the start point.
    void
    run(
      FloatType const& gtol,
      FloatType const& stpmin,
      FloatType const& stpmax,
      unsigned n,
      FloatType* x,
      FloatType f,
      const FloatType* g,
      const FloatType* s,
      FloatType& stp,
      FloatType ftol,
      FloatType xtol,
      unsigned maxfev,
      int& info,
      unsigned& nfev);

    /// One safeguarded update of the interval of uncertainty (MCSTEP).
    ///
    /// (stx, fx, dx) is the endpoint with the lowest function value so far,
    /// (sty, fy, dy) the other endpoint, and (stp, fp, dp) the latest trial
    /// step with its function value and directional derivative.  Returns 0
    /// and leaves every argument unchanged if the input is inconsistent;
    /// otherwise updates the endpoints, may set brackt, stores the next
    /// trial step in stp and returns the case number 1..4 that was used.
    static
    int
    mcstep(
      FloatType& stx,
      FloatType& fx,
      FloatType& dx,
      FloatType& sty,
      FloatType& fy,
      FloatType& dy,
      FloatType& stp,
      FloatType fp,
      FloatType dp,
      bool& brackt,
      FloatType stpmin,
      FloatType stpmax);
};
00236
00237 template <typename FloatType>
00238 void mcsrch<FloatType>::run(
00239 FloatType const& gtol,
00240 FloatType const& stpmin,
00241 FloatType const& stpmax,
00242 unsigned n,
00243 FloatType* x,
00244 FloatType f,
00245 const FloatType* g,
00246 const FloatType* s,
00247 FloatType& stp,
00248 FloatType ftol,
00249 FloatType xtol,
00250 unsigned maxfev,
00251 int& info,
00252 unsigned& nfev)
00253 {
00254 if (info != -1) {
00255 infoc = 1;
00256 if ( n == 0
00257 || maxfev == 0
00258 || gtol < FloatType(0)
00259 || xtol < FloatType(0)
00260 || stpmin < FloatType(0)
00261 || stpmax < stpmin) {
00262 throw std::runtime_error("Improper input parameters.");
00263 }
00264 if (stp <= FloatType(0) || ftol < FloatType(0)) {
00265 throw std::runtime_error("Improper value for stp or ftol.");
00266 }
00267
00268
00269 dginit = FloatType(0);
00270 for (unsigned j = 0; j < n; j++) {
00271 dginit += g[j] * s[j];
00272 }
00273 if (dginit >= FloatType(0)) {
00274 throw std::runtime_error("Search direction not descent.");
00275 }
00276 brackt = false;
00277 stage1 = true;
00278 nfev = 0;
00279 finit = f;
00280 dgtest = ftol*dginit;
00281 width = stpmax - stpmin;
00282 width1 = FloatType(2) * width;
00283 initial_x.assign(x, x+n);
00284
00285
00286
00287
00288
00289
00290
00291 stx = FloatType(0);
00292 fx = finit;
00293 dgx = dginit;
00294 sty = FloatType(0);
00295 fy = finit;
00296 dgy = dginit;
00297 }
00298 for (;;) {
00299 if (info != -1) {
00300
00301
00302 if (brackt) {
00303 stmin = std::min(stx, sty);
00304 stmax = std::max(stx, sty);
00305 }
00306 else {
00307 stmin = stx;
00308 stmax = stp + FloatType(4) * (stp - stx);
00309 }
00310
00311 stp = std::max(stp, stpmin);
00312 stp = std::min(stp, stpmax);
00313
00314
00315 if ( (brackt && (stp <= stmin || stp >= stmax))
00316 || nfev >= maxfev - 1 || infoc == 0
00317 || (brackt && stmax - stmin <= xtol * stmax)) {
00318 stp = stx;
00319 }
00320
00321
00322
00323 for (unsigned j = 0; j < n; j++) {
00324 x[j] = initial_x[j] + stp * s[j];
00325 }
00326 info = -1;
00327 info_meaning =
00328 "A return is made to compute the function and gradient.";
00329 break;
00330 }
00331 info = 0;
00332 info_meaning = 0;
00333 nfev++;
00334 FloatType dg(0);
00335 for (unsigned j = 0; j < n; j++) {
00336 dg += g[j] * s[j];
00337 }
00338 FloatType ftest1 = finit + stp*dgtest;
00339
00340 if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) {
00341 info = 6;
00342 info_meaning =
00343 "Rounding errors prevent further progress."
00344 " There may not be a step which satisfies the"
00345 " sufficient decrease and curvature conditions."
00346 " Tolerances may be too small.";
00347 break;
00348 }
00349 if (stp == stpmax && f <= ftest1 && dg <= dgtest) {
00350 info = 5;
00351 info_meaning = "The step is at the upper bound stpmax.";
00352 break;
00353 }
00354 if (stp == stpmin && (f > ftest1 || dg >= dgtest)) {
00355 info = 4;
00356 info_meaning = "The step is at the lower bound stpmin.";
00357 break;
00358 }
00359 if (nfev >= maxfev) {
00360 info = 3;
00361 info_meaning = "Number of function evaluations has reached maxfev.";
00362 break;
00363 }
00364 if (brackt && stmax - stmin <= xtol * stmax) {
00365 info = 2;
00366 info_meaning =
00367 "Relative width of the interval of uncertainty"
00368 " is at most xtol.";
00369 break;
00370 }
00371
00372 if (f <= ftest1 && std::abs(dg) <= gtol * (-dginit)) {
00373 info = 1;
00374 info_meaning =
00375 "The sufficient decrease condition and the"
00376 " directional derivative condition hold.";
00377 break;
00378 }
00379
00380
00381 if ( stage1 && f <= ftest1
00382 && dg >= std::min(ftol, gtol) * dginit) {
00383 stage1 = false;
00384 }
00385
00386
00387
00388
00389
00390 if (stage1 && f <= fx && f > ftest1) {
00391
00392 FloatType fm = f - stp*dgtest;
00393 FloatType fxm = fx - stx*dgtest;
00394 FloatType fym = fy - sty*dgtest;
00395 FloatType dgm = dg - dgtest;
00396 FloatType dgxm = dgx - dgtest;
00397 FloatType dgym = dgy - dgtest;
00398
00399
00400 infoc = mcstep(stx, fxm, dgxm, sty, fym, dgym, stp, fm, dgm,
00401 brackt, stmin, stmax);
00402
00403 fx = fxm + stx*dgtest;
00404 fy = fym + sty*dgtest;
00405 dgx = dgxm + dgtest;
00406 dgy = dgym + dgtest;
00407 }
00408 else {
00409
00410
00411 infoc = mcstep(stx, fx, dgx, sty, fy, dgy, stp, f, dg,
00412 brackt, stmin, stmax);
00413 }
00414
00415
00416 if (brackt) {
00417 if (std::abs(sty - stx) >= FloatType(0.66) * width1) {
00418 stp = stx + FloatType(0.5) * (sty - stx);
00419 }
00420 width1 = width;
00421 width = std::abs(sty - stx);
00422 }
00423 }
00424 }
00425
// Safeguarded step computation (MCSTEP) for the Moré-Thuente line search.
//
// (stx, fx, dx) is the endpoint with the lowest function value found so
// far, (sty, fy, dy) the other endpoint of the interval of uncertainty, and
// (stp, fp, dp) the latest trial step with its function value and
// directional derivative.  On success the endpoints are updated, brackt is
// set when a bracketing interval is found (cases 1 and 2), and stp receives
// the next trial step, clamped to [stpmin, stpmax].
//
// Returns 0 without modifying any argument when the input is inconsistent
// (stp outside the bracketed interval, dx*(stp - stx) >= 0, or
// stpmax < stpmin).  Otherwise returns the case that was used:
//   1: fp > fx
//   2: fp <= fx and dp, dx have opposite signs
//   3: fp <= fx, same signs, |dp| < |dx|
//   4: fp <= fx, same signs, |dp| >= |dx|
template <typename FloatType>
int mcsrch<FloatType>::mcstep(
  FloatType& stx,
  FloatType& fx,
  FloatType& dx,
  FloatType& sty,
  FloatType& fy,
  FloatType& dy,
  FloatType& stp,
  FloatType fp,
  FloatType dp,
  bool& brackt,
  FloatType stpmin,
  FloatType stpmax)
{
  bool bound;
  FloatType gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta;
  int info = 0;
  // Reject inconsistent input: stp must lie strictly inside a bracketed
  // interval, and dx must be a descent slope in the direction of stp.
  if ( ( brackt && (stp <= std::min(stx, sty)
                 || stp >= std::max(stx, sty)))
      || dx * (stp - stx) >= FloatType(0) || stpmax < stpmin) {
    return 0;
  }

  // sgnd < 0 iff dp and dx have opposite signs (dx != 0 after the check
  // above).
  sgnd = dp * (dx / std::abs(dx));
  if (fp > fx) {
    // Case 1: higher function value; the interval becomes bracketed.
    // Take the cubic step if it is closer to stx than the quadratic step,
    // otherwise the average of the two.  The cubic terms are scaled by s
    // to keep intermediate products in range.
    info = 1;
    bound = true;
    theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
    s = max3(std::abs(theta), std::abs(dx), std::abs(dp));
    gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s));
    if (stp < stx) gamma = - gamma;
    p = (gamma - dx) + theta;
    q = ((gamma - dx) + gamma) + dp;
    r = p/q;
    stpc = stx + r * (stp - stx);
    stpq = stx
      + ((dx / ((fx - fp) / (stp - stx) + dx)) / FloatType(2))
      * (stp - stx);
    if (std::abs(stpc - stx) < std::abs(stpq - stx)) {
      stpf = stpc;
    }
    else {
      stpf = stpc + (stpq - stpc) / FloatType(2);
    }
    brackt = true;
  }
  else if (sgnd < FloatType(0)) {
    // Case 2: lower function value and derivatives of opposite sign; the
    // interval becomes bracketed.  Take the cubic step if it is farther
    // from stp than the secant step, otherwise the secant step.
    info = 2;
    bound = false;
    theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
    s = max3(std::abs(theta), std::abs(dx), std::abs(dp));
    gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s));
    if (stp > stx) gamma = - gamma;
    p = (gamma - dp) + theta;
    q = ((gamma - dp) + gamma) + dx;
    r = p/q;
    stpc = stp + r * (stx - stp);
    stpq = stp + (dp / (dp - dx)) * (stx - stp);
    if (std::abs(stpc - stp) > std::abs(stpq - stp)) {
      stpf = stpc;
    }
    else {
      stpf = stpq;
    }
    brackt = true;
  }
  else if (std::abs(dp) < std::abs(dx)) {
    // Case 3: lower function value, derivatives of the same sign, and the
    // derivative magnitude decreases.  The discriminant is clamped at zero
    // here; when gamma == 0 or r >= 0 the cubic step is replaced by the
    // bound in the direction of the step.
    info = 3;
    bound = true;
    theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
    s = max3(std::abs(theta), std::abs(dx), std::abs(dp));
    gamma = s * std::sqrt(
      std::max(FloatType(0), pow2(theta / s) - (dx / s) * (dp / s)));
    if (stp > stx) gamma = -gamma;
    p = (gamma - dp) + theta;
    q = (gamma + (dx - dp)) + gamma;
    r = p/q;
    if (r < FloatType(0) && gamma != FloatType(0)) {
      stpc = stp + r * (stx - stp);
    }
    else if (stp > stx) {
      stpc = stpmax;
    }
    else {
      stpc = stpmin;
    }
    stpq = stp + (dp / (dp - dx)) * (stx - stp);
    if (brackt) {
      // Bracketed: take the step (cubic or secant) closer to stp.
      if (std::abs(stp - stpc) < std::abs(stp - stpq)) {
        stpf = stpc;
      }
      else {
        stpf = stpq;
      }
    }
    else {
      // Not bracketed: take the step (cubic or secant) farther from stp.
      if (std::abs(stp - stpc) > std::abs(stp - stpq)) {
        stpf = stpc;
      }
      else {
        stpf = stpq;
      }
    }
  }
  else {
    // Case 4: lower function value, derivatives of the same sign, and the
    // derivative magnitude does not decrease.  If bracketed, use the cubic
    // step through (sty, fy, dy) and (stp, fp, dp); otherwise move to the
    // bound in the direction of the step.
    info = 4;
    bound = false;
    if (brackt) {
      theta = FloatType(3) * (fp - fy) / (sty - stp) + dy + dp;
      s = max3(std::abs(theta), std::abs(dy), std::abs(dp));
      gamma = s * std::sqrt(pow2(theta / s) - (dy / s) * (dp / s));
      if (stp > sty) gamma = -gamma;
      p = (gamma - dp) + theta;
      q = ((gamma - dp) + gamma) + dy;
      r = p/q;
      stpc = stp + r * (sty - stp);
      stpf = stpc;
    }
    else if (stp > stx) {
      stpf = stpmax;
    }
    else {
      stpf = stpmin;
    }
  }

  // Update the interval of uncertainty.  This depends only on fp, fx and
  // sgnd, not on the new step.
  if (fp > fx) {
    sty = stp;
    fy = fp;
    dy = dp;
  }
  else {
    if (sgnd < FloatType(0)) {
      sty = stx;
      fy = fx;
      dy = dx;
    }
    stx = stp;
    fx = fp;
    dx = dp;
  }

  // Clamp the new step to [stpmin, stpmax].  In cases 1 and 3 with a
  // bracketed interval, also keep it no farther than 0.66 of the way from
  // stx towards sty.
  stpf = std::min(stpmax, stpf);
  stpf = std::max(stpmin, stpf);
  stp = stpf;
  if (brackt && bound) {
    if (sty > stx) {
      stp = std::min(stx + FloatType(0.66) * (sty - stx), stp);
    }
    else {
      stp = std::max(stx + FloatType(0.66) * (sty - stx), stp);
    }
  }
  return info;
}
00603
00604 }}
00605
00606 #endif // SCITBX_LINE_SEARCH_MORE_THUENTE_1994_RAW_H