Main Lemma Repository

bicgstab.h 2.6KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  /*
   * Modified from http://www.netlib.org/templates/cpp/
   *
   * Iterative template routine -- BiCGSTAB
   *
   * BiCGSTAB solves the unsymmetric linear system Ax = b
   * using the Preconditioned BiConjugate Gradient Stabilized method.
   * BiCGSTAB follows the algorithm described on p. 27 of the
   * SIAM Templates book.
   *
   * Unlike the netlib original, which returned an integer status code
   * (0 on convergence within max_iter iterations, 1 otherwise), this
   * version returns the approximate solution x itself.
   * Upon successful return, the in/out arguments have the following values:
   *   max_iter -- the number of iterations performed before the
   *               tolerance was reached
   *   tol      -- the relative residual ||b - Ax|| / ||b|| after the final
   *               iteration (the absolute residual if b = 0)
   */
  16. #include <iostream>
  17. #include <fstream>
  18. #include "lemma.h"
  19. namespace Lemma {
  20. template <typename Scalar>
  21. VectorXr BiCGSTAB(const MatrixXr &A, const VectorXr &x0, const VectorXr &b,
  22. const MatrixXr &M, int &max_iter, Scalar& tol) {
  23. Scalar resid;
  24. Scalar rho_1(0), rho_2(0), alpha(0), beta(0), omega(0);
  25. VectorXr p, phat, s, shat, t, v;
  26. VectorXr x = x0;
  27. Scalar normb = b.norm();
  28. VectorXr r = b - A * x;
  29. VectorXr rtilde = r;
  30. if (normb == 0.0)
  31. normb = 1;
  32. if ((resid = r.norm() / normb) <= tol) {
  33. tol = resid;
  34. max_iter = 0;
  35. //return 0;
  36. return x;
  37. }
  38. for (int i = 1; i <= max_iter; i++) {
  39. rho_1 = rtilde.dot(r);
  40. if (rho_1 == 0) {
  41. tol = r.norm() / normb;
  42. //return 2;
  43. return x;
  44. }
  45. if (i == 1)
  46. p = r;
  47. else {
  48. beta = (rho_1/rho_2) * (alpha/omega);
  49. p = r + beta * (p - omega * v);
  50. }
  51. phat = M*p; //M.solve(p);
  52. v = A * phat;
  53. alpha = rho_1 / rtilde.dot(v);
  54. s = r - alpha * v;
  55. if ((resid = s.norm()/normb) < tol) {
  56. x += alpha * phat;
  57. tol = resid;
  58. //return 0;
  59. return x;
  60. }
  61. shat = M*s;//M.solve(s);
  62. t = A * shat;
  63. omega = t.dot(s) / t.dot(t);
  64. x += alpha * phat + omega * shat;
  65. r = s - omega * t;
  66. rho_2 = rho_1;
  67. if ((resid = r.norm() / normb) < tol) {
  68. tol = resid;
  69. max_iter = i;
  70. //return 0;
  71. return x;
  72. }
  73. if (omega == 0) {
  74. tol = r.norm() / normb;
  75. //return 3;
  76. return x;
  77. }
  78. }
  79. tol = resid;
  80. return x;
  81. }
  82. }
  83. /* vim: set tabstop=4 expandtab: */
  84. /* vim: set filetype=cpp: */