SHOGUN  4.2.0
ConjugateOrthogonalCGSolver.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Soumyajit De
8  */
9
10 #include <shogun/lib/common.h>
11
12
13 #include <shogun/lib/SGVector.h>
14 #include <shogun/lib/Time.h>
17
20 using namespace Eigen;
21
22 namespace shogun
23 {
24
25 CConjugateOrthogonalCGSolver::CConjugateOrthogonalCGSolver()
27 {
28  SG_GCDEBUG("%s created (%p)\n", this->get_name(), this);
29 }
30
32  : CIterativeLinearSolver<complex128_t, float64_t>(store_residuals)
33 {
34  SG_GCDEBUG("%s created (%p)\n", this->get_name(), this);
35 }
36
38 {
39  SG_GCDEBUG("%s destroyed (%p)\n", this->get_name(), this);
40 }
41
44 {
45  SG_DEBUG("CConjugateOrthogonalCGSolver::solve(): Entering..\n");
46
47  // sanity check
48  REQUIRE(A, "Operator is NULL!\n");
49  REQUIRE(A->get_dimension()==b.vlen, "Dimension mismatch!\n, %d vs %d",
50  A->get_dimension(), b.vlen);
51
52  // the final solution vector, initial guess is 0
53  SGVector<complex128_t> result(b.vlen);
54  result.set_const(0.0);
55
56  // the rest of the part hinges on eigen3 for computing norms
57  Map<VectorXcd> x(result.vector, result.vlen);
58  Map<VectorXd> b_map(b.vector, b.vlen);
59
60  // direction vector
61  SGVector<complex128_t> p_(result.vlen);
62  Map<VectorXcd> p(p_.vector, p_.vlen);
63
64  // residual r_i=b-Ax_i, here x_0=[0], so r_0=b
65  VectorXcd r=b_map.cast<complex128_t>();
66
67  // initial direction is same as residual
68  p=r;
69
70  // the iterator for this iterative solver
73
74  // start the timer
75  CTime time;
76  time.start();
77
78  // set the residuals to zero
81
82  // CG iteration begins
83  complex128_t r_norm2=r.transpose()*r;
84
85  for (it.begin(r); !it.end(r); ++it)
86  {
87  SG_DEBUG("CG iteration %d, residual norm %f\n",
88  it.get_iter_info().iteration_count,
89  it.get_iter_info().residual_norm);
90
92  {
93  m_residuals[it.get_iter_info().iteration_count]
94  =it.get_iter_info().residual_norm;
95  }
96
97  // apply linear operator to the direction vector
98  SGVector<complex128_t> Ap_=A->apply(p_);
99  Map<VectorXcd> Ap(Ap_.vector, Ap_.vlen);
100
101  // compute p^{T}Ap, if zero, failure
102  complex128_t p_T_times_Ap=p.transpose()*Ap;
103  if (p_T_times_Ap==0.0)
104  break;
105
106  // compute the alpha parameter of CG
107  complex128_t alpha=r_norm2/p_T_times_Ap;
108
109  // update the solution vector and residual
110  // x_{i}=x_{i-1}+\alpha_{i}p
111  x+=alpha*p;
112
113  // r_{i}=r_{i-1}-\alpha_{i}p
114  r-=alpha*Ap;
115
116  // compute new ||r||_{2}, if zero, converged
117  complex128_t r_norm2_i=r.transpose()*r;
118  if (r_norm2_i==0.0)
119  break;
120
121  // compute the beta parameter of CG
122  complex128_t beta=r_norm2_i/r_norm2;
123
124  // update direction, and ||r||_{2}
125  r_norm2=r_norm2_i;
126  p=r+beta*p;
127  }
128
129  float64_t elapsed=time.cur_time_diff();
130
131  if (!it.succeeded(r))
132  SG_WARNING("Did not converge!\n");
133
134  SG_INFO("Iteration took %ld times, residual norm=%.20lf, time elapsed=%lf\n",
135  it.get_iter_info().iteration_count, it.get_iter_info().residual_norm, elapsed);
136
137  SG_DEBUG("CConjugateOrthogonalCGSolver::solve(): Leaving..\n");
138  return result;
139 }
140
141 }
Class Time implements a stopwatch based on either CPU time or wall-clock time.
Definition: Time.h:47
#define SG_INFO(...)
Definition: SGIO.h:118
std::complex< float64_t > complex128_t
Definition: common.h:67
void begin(const VectorXt &residual)
const index_t get_dimension() const
Definition: SGMatrix.h:20
#define REQUIRE(x,...)
Definition: SGIO.h:206
const bool end(const VectorXt &residual)
float64_t cur_time_diff(bool verbose=false)
Definition: Time.cpp:68
index_t vlen
Definition: SGVector.h:494
template class that is used as an iterator for an iterative linear solver. In the iteration of solvin...
#define SG_GCDEBUG(...)
Definition: SGIO.h:102
abstract template base for all iterative linear solvers such as conjugate gradient (CG) solvers...
double float64_t
Definition: common.h:50
float64_t start(bool verbose=false)
Definition: Time.cpp:59
virtual SGVector< complex128_t > solve(CLinearOperator< complex128_t > *A, SGVector< float64_t > b)
virtual SGVector< T > apply(SGVector< T > b) const =0
#define SG_DEBUG(...)
Definition: SGIO.h:107
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
const bool succeeded(const VectorXt &residual)
#define SG_WARNING(...)
Definition: SGIO.h:128
void set_const(T const_elem)
Definition: SGVector.cpp:150

SHOGUN Machine Learning Toolbox - Documentation