SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Integration.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Roman Votyakov
8  *
9  * The abscissae and weights for Gauss-Kronrod rules are taken form
10  * QUADPACK, which is in public domain.
11  * http://www.netlib.org/quadpack/
12  *
13  * See header file for which functions are adapted from GNU Octave,
14  * file quadgk.m: Copyright (C) 2008-2012 David Bateman under GPLv3
15  * http://www.gnu.org/software/octave/
16  */
17 
19 
20 #ifdef HAVE_EIGEN3
21 
23 
24 using namespace shogun;
25 using namespace Eigen;
26 
27 namespace shogun
28 {
29 
30 #ifndef DOXYGEN_SHOULD_SKIP_THIS
31 
42 class CITransformFunction : public CFunction
43 {
44 public:
49  CITransformFunction(CFunction* f)
50  {
51  SG_REF(f);
52  m_f=f;
53  }
54 
55  virtual ~CITransformFunction() { SG_UNREF(m_f); }
56 
64  virtual float64_t operator() (float64_t x)
65  {
66  float64_t hx=1.0/(1.0-CMath::sq(x));
67  float64_t gx=x*hx;
68  float64_t dgx=(1.0+CMath::sq(x))*CMath::sq(hx);
69 
70  return (*m_f)(gx)*dgx;
71  }
72 
73 private:
75  CFunction* m_f;
76 };
77 
93 class CILTransformFunction : public CFunction
94 {
95 public:
101  CILTransformFunction(CFunction* f, float64_t b)
102  {
103  SG_REF(f);
104  m_f=f;
105  m_b=b;
106  }
107 
108  virtual ~CILTransformFunction() { SG_UNREF(m_f); }
109 
117  virtual float64_t operator() (float64_t x)
118  {
119  float64_t hx=1.0/(1.0+x);
120  float64_t gx=x*hx;
121  float64_t dgx=CMath::sq(hx);
122 
123  return -(*m_f)(m_b-CMath::sq(gx))*2*gx*dgx;
124  }
125 
126 private:
128  CFunction* m_f;
129 
131  float64_t m_b;
132 };
133 
149 class CIUTransformFunction : public CFunction
150 {
151 public:
157  CIUTransformFunction(CFunction* f, float64_t a)
158  {
159  SG_REF(f);
160  m_f=f;
161  m_a=a;
162  }
163 
164  virtual ~CIUTransformFunction() { SG_UNREF(m_f); }
165 
173  virtual float64_t operator() (float64_t x)
174  {
175  float64_t hx=1.0/(1.0-x);
176  float64_t gx=x*hx;
177  float64_t dgx=CMath::sq(hx);
178 
179  return (*m_f)(m_a+CMath::sq(gx))*2*gx*dgx;
180  }
181 
182 private:
184  CFunction* m_f;
185 
187  float64_t m_a;
188 };
189 
200 class CTransformFunction : public CFunction
201 {
202 public:
209  CTransformFunction(CFunction* f, float64_t a, float64_t b)
210  {
211  SG_REF(f);
212  m_f=f;
213  m_a=a;
214  m_b=b;
215  }
216 
217  virtual ~CTransformFunction() { SG_UNREF(m_f); }
218 
227  virtual float64_t operator() (float64_t x)
228  {
229  float64_t qw=(m_b-m_a)/4.0;
230  float64_t gx=qw*(x*(3.0-CMath::sq(x)))+(m_b+m_a)/2.0;
231  float64_t dgx=qw*3.0*(1.0-CMath::sq(x));
232 
233  return (*m_f)(gx)*dgx;
234  }
235 
236 private:
238  CFunction* m_f;
239 
241  float64_t m_a;
242 
244  float64_t m_b;
245 };
246 
247 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
248 
250  float64_t b, float64_t abs_tol, float64_t rel_tol, uint32_t max_iter,
251  index_t sn)
252 {
253  // check the parameters
254  REQUIRE(f, "Integrable function should not be NULL\n")
255  REQUIRE(abs_tol>0.0, "Absolute tolerance must be positive, but is %f\n",
256  abs_tol)
257  REQUIRE(rel_tol>0.0, "Relative tolerance must be positive, but is %f\n",
258  rel_tol)
259  REQUIRE(max_iter>0, "Maximum number of iterations must be greater than 0, "
260  "but is %d\n", max_iter)
261  REQUIRE(sn>0, "Initial number of subintervals must be greater than 0, "
262  "but is %d\n", sn)
263 
264  // integral evaluation function
265  typedef void TQuadGKEvaluationFunction(CFunction* f,
268 
269  TQuadGKEvaluationFunction* evaluate_quadgk;
270 
271  CFunction* tf;
272  float64_t ta;
273  float64_t tb;
274  float64_t q_sign;
275 
276  // negate integral value and swap a and b, if a>b
277  if (a>b)
278  {
279  ta=b;
280  tb=a;
281  q_sign=-1.0;
282  }
283  else
284  {
285  ta=a;
286  tb=b;
287  q_sign=1.0;
288  }
289 
290  // transform integrable function and domain of integration
291  if (a==-CMath::INFTY && b==CMath::INFTY)
292  {
293  tf=new CITransformFunction(f);
294  evaluate_quadgk=&evaluate_quadgk15;
295  ta=-1.0;
296  tb=1.0;
297  }
298  else if (a==-CMath::INFTY)
299  {
300  tf=new CILTransformFunction(f, b);
301  evaluate_quadgk=&evaluate_quadgk15;
302  ta=-1.0;
303  tb=0.0;
304  }
305  else if (b==CMath::INFTY)
306  {
307  tf=new CIUTransformFunction(f, a);
308  evaluate_quadgk=&evaluate_quadgk15;
309  ta=0.0;
310  tb=1.0;
311  }
312  else
313  {
314  tf=new CTransformFunction(f, a, b);
315  evaluate_quadgk=&evaluate_quadgk21;
316  ta=-1.0;
317  tb=1.0;
318  }
319 
320  // compute initial subintervals, by dividing domain [a, b] into sn
321  // parts
323 
324  // width of each subinterval
325  float64_t sw=(tb-ta)/sn;
326 
327  for (index_t i=0; i<sn; i++)
328  {
329  subs->push_back(ta+i*sw);
330  subs->push_back(ta+(i+1)*sw);
331  }
332 
333  // evaluate integrals on initial subintervals
336 
337  evaluate_quadgk(tf, subs, q_subs, err_subs);
338 
339  // compute value of integral and error on [a, b]
340  float64_t q=0.0;
341  float64_t err=0.0;
342 
343  for (index_t i=0; i<q_subs->get_num_elements(); i++)
344  q+=(*q_subs)[i];
345 
346  for (index_t i=0; i<err_subs->get_num_elements(); i++)
347  err+=(*err_subs)[i];
348 
349  // evaluate tolerance
350  float64_t tol=CMath::max(abs_tol, rel_tol*CMath::abs(q));
351 
352  // number of iterations
353  uint32_t iter=1;
354 
356 
357  while (err>tol && iter<max_iter)
358  {
359  // choose and bisect subintervals with estimated error, which
360  // is larger or equal to tolerance
361  for (index_t i=0; i<subs->get_num_elements()/2; i++)
362  {
363  if (CMath::abs((*err_subs)[i])>=tol*CMath::abs((*subs)[2*i+1]-
364  (*subs)[2*i])/(tb-ta))
365  {
366  // bisect subinterval
367  float64_t mid=((*subs)[2*i]+(*subs)[2*i+1])/2.0;
368 
369  new_subs->push_back((*subs)[2*i]);
370  new_subs->push_back(mid);
371  new_subs->push_back(mid);
372  new_subs->push_back((*subs)[2*i+1]);
373 
374  // subtract value of the integral and error on this
375  // subinterval from total value and error
376  q-=(*q_subs)[i];
377  err-=(*err_subs)[i];
378  }
379  }
380 
381  subs->set_array(new_subs->get_array(), new_subs->get_num_elements(),
382  new_subs->get_num_elements());
383 
384  new_subs->reset_array();
385 
386  // break if no new subintervals
387  if (!subs->get_num_elements())
388  break;
389 
390  // evaluate integrals on selected subintervals
391  evaluate_quadgk(tf, subs, q_subs, err_subs);
392 
393  for (index_t i=0; i<q_subs->get_num_elements(); i++)
394  q+=(*q_subs)[i];
395 
396  for (index_t i=0; i<err_subs->get_num_elements(); i++)
397  err+=(*err_subs)[i];
398 
399  // evaluate tolerance
400  tol=CMath::max(abs_tol, rel_tol*CMath::abs(q));
401 
402  iter++;
403  }
404 
405  SG_UNREF(new_subs);
406 
407  if (err>tol)
408  {
409  SG_SWARNING("Error tolerance not met. Estimated error is equal to %g "
410  "after %d iterations\n", err, iter)
411  }
412 
413  // clean up
414  SG_UNREF(subs);
415  SG_UNREF(q_subs);
416  SG_UNREF(err_subs);
417  SG_UNREF(tf);
418 
419  return q_sign*q;
420 }
421 
423 {
424  SG_REF(f);
425 
426  // evaluate integral using Gauss-Hermite 64-point rule
427  float64_t q=evaluate_quadgh64(f);
428 
429  SG_UNREF(f);
430 
431  return q;
432 }
433 
434 void CIntegration::evaluate_quadgk(CFunction* f, CDynamicArray<float64_t>* subs,
436  float64_t* xgk, float64_t* wg, float64_t* wgk)
437 {
438  // check the parameters
439  REQUIRE(f, "Integrable function should not be NULL\n")
440  REQUIRE(subs, "Array of subintervals should not be NULL\n")
441  REQUIRE(!(subs->get_array_size()%2), "Size of the array of subintervals "
442  "should be even\n")
443  REQUIRE(q, "Array of values of integrals should not be NULL\n")
444  REQUIRE(err, "Array of errors should not be NULL\n")
445  REQUIRE(n%2, "Order of Gauss-Kronrod should be odd\n")
446  REQUIRE(xgk, "Gauss-Kronrod nodes should not be NULL\n")
447  REQUIRE(wgk, "Gauss-Kronrod weights should not be NULL\n")
448  REQUIRE(wg, "Gauss weights should not be NULL\n")
449 
450  // create eigen representation of subs, xgk, wg, wgk
451  Map<MatrixXd> eigen_subs(subs->get_array(), 2, subs->get_num_elements()/2);
452  Map<VectorXd> eigen_xgk(xgk, n);
453  Map<VectorXd> eigen_wg(wg, n/2);
454  Map<VectorXd> eigen_wgk(wgk, n);
455 
456  // compute half width and centers of each subinterval
457  VectorXd eigen_hw=(eigen_subs.row(1)-eigen_subs.row(0))/2.0;
458  VectorXd eigen_center=eigen_subs.colwise().sum()/2.0;
459 
460  // compute Gauss-Kronrod nodes x for each subinterval: x=hw*xgk+center
461  MatrixXd x=eigen_hw*eigen_xgk.adjoint()+eigen_center*
462  (VectorXd::Ones(n)).adjoint();
463 
464  // compute ygk=f(x)
465  MatrixXd ygk(x.rows(), x.cols());
466 
467  for (index_t i=0; i<ygk.rows(); i++)
468  for (index_t j=0; j<ygk.cols(); j++)
469  ygk(i,j)=(*f)(x(i,j));
470 
471  // compute value of definite integral on each subinterval
472  VectorXd eigen_q=((ygk*eigen_wgk.asDiagonal()).rowwise().sum()).cwiseProduct(
473  eigen_hw);
474  q->set_array(eigen_q.data(), eigen_q.size());
475 
476  // choose function values for Gauss nodes
477  MatrixXd yg(ygk.rows(), ygk.cols()/2);
478 
479  for (index_t i=1, j=0; i<ygk.cols(); i+=2, j++)
480  yg.col(j)=ygk.col(i);
481 
482  // compute error on each subinterval
483  VectorXd eigen_err=(((yg*eigen_wg.asDiagonal()).rowwise().sum()).cwiseProduct(
484  eigen_hw)-eigen_q).array().abs();
485  err->set_array(eigen_err.data(), eigen_err.size());
486 }
487 
488 void CIntegration::evaluate_quadgk15(CFunction* f, CDynamicArray<float64_t>* subs,
490 {
491  static const index_t n=15;
492 
493  // Gauss-Kronrod nodes
494  static float64_t xgk[n]=
495  {
496  -0.991455371120812639206854697526329,
497  -0.949107912342758524526189684047851,
498  -0.864864423359769072789712788640926,
499  -0.741531185599394439863864773280788,
500  -0.586087235467691130294144838258730,
501  -0.405845151377397166906606412076961,
502  -0.207784955007898467600689403773245,
503  0.000000000000000000000000000000000,
504  0.207784955007898467600689403773245,
505  0.405845151377397166906606412076961,
506  0.586087235467691130294144838258730,
507  0.741531185599394439863864773280788,
508  0.864864423359769072789712788640926,
509  0.949107912342758524526189684047851,
510  0.991455371120812639206854697526329
511  };
512 
513  // Gauss weights
514  static float64_t wg[n/2]=
515  {
516  0.129484966168869693270611432679082,
517  0.279705391489276667901467771423780,
518  0.381830050505118944950369775488975,
519  0.417959183673469387755102040816327,
520  0.381830050505118944950369775488975,
521  0.279705391489276667901467771423780,
522  0.129484966168869693270611432679082
523  };
524 
525  // Gauss-Kronrod weights
526  static float64_t wgk[n]=
527  {
528  0.022935322010529224963732008058970,
529  0.063092092629978553290700663189204,
530  0.104790010322250183839876322541518,
531  0.140653259715525918745189590510238,
532  0.169004726639267902826583426598550,
533  0.190350578064785409913256402421014,
534  0.204432940075298892414161999234649,
535  0.209482141084727828012999174891714,
536  0.204432940075298892414161999234649,
537  0.190350578064785409913256402421014,
538  0.169004726639267902826583426598550,
539  0.140653259715525918745189590510238,
540  0.104790010322250183839876322541518,
541  0.063092092629978553290700663189204,
542  0.022935322010529224963732008058970
543  };
544 
545  // evaluate definite integral on each subinterval using Gauss-Kronrod rule
546  evaluate_quadgk(f, subs, q, err, n, xgk, wg, wgk);
547 }
548 
549 void CIntegration::evaluate_quadgk21(CFunction* f, CDynamicArray<float64_t>* subs,
551 {
552  static const index_t n=21;
553 
554  // Gauss-Kronrod nodes
555  static float64_t xgk[n]=
556  {
557  -0.995657163025808080735527280689003,
558  -0.973906528517171720077964012084452,
559  -0.930157491355708226001207180059508,
560  -0.865063366688984510732096688423493,
561  -0.780817726586416897063717578345042,
562  -0.679409568299024406234327365114874,
563  -0.562757134668604683339000099272694,
564  -0.433395394129247190799265943165784,
565  -0.294392862701460198131126603103866,
566  -0.148874338981631210884826001129720,
567  0.000000000000000000000000000000000,
568  0.148874338981631210884826001129720,
569  0.294392862701460198131126603103866,
570  0.433395394129247190799265943165784,
571  0.562757134668604683339000099272694,
572  0.679409568299024406234327365114874,
573  0.780817726586416897063717578345042,
574  0.865063366688984510732096688423493,
575  0.930157491355708226001207180059508,
576  0.973906528517171720077964012084452,
577  0.995657163025808080735527280689003
578  };
579 
580  // Gauss weights
581  static float64_t wg[n/2]=
582  {
583  0.066671344308688137593568809893332,
584  0.149451349150580593145776339657697,
585  0.219086362515982043995534934228163,
586  0.269266719309996355091226921569469,
587  0.295524224714752870173892994651338,
588  0.295524224714752870173892994651338,
589  0.269266719309996355091226921569469,
590  0.219086362515982043995534934228163,
591  0.149451349150580593145776339657697,
592  0.066671344308688137593568809893332
593  };
594 
595  // Gauss-Kronrod weights
596  static float64_t wgk[n]=
597  {
598  0.011694638867371874278064396062192,
599  0.032558162307964727478818972459390,
600  0.054755896574351996031381300244580,
601  0.075039674810919952767043140916190,
602  0.093125454583697605535065465083366,
603  0.109387158802297641899210590325805,
604  0.123491976262065851077958109831074,
605  0.134709217311473325928054001771707,
606  0.142775938577060080797094273138717,
607  0.147739104901338491374841515972068,
608  0.149445554002916905664936468389821,
609  0.147739104901338491374841515972068,
610  0.142775938577060080797094273138717,
611  0.134709217311473325928054001771707,
612  0.123491976262065851077958109831074,
613  0.109387158802297641899210590325805,
614  0.093125454583697605535065465083366,
615  0.075039674810919952767043140916190,
616  0.054755896574351996031381300244580,
617  0.032558162307964727478818972459390,
618  0.011694638867371874278064396062192
619  };
620 
621  evaluate_quadgk(f, subs, q, err, n, xgk, wg, wgk);
622 }
623 
624 float64_t CIntegration::evaluate_quadgh(CFunction* f, index_t n, float64_t* xgh,
625  float64_t* wgh)
626 {
627  // check the parameters
628  REQUIRE(f, "Integrable function should not be NULL\n");
629  REQUIRE(xgh, "Gauss-Hermite nodes should not be NULL\n");
630  REQUIRE(wgh, "Gauss-Hermite weights should not be NULL\n");
631 
632  float64_t q=0.0;
633 
634  for (index_t i=0; i<n; i++)
635  q+=wgh[i]*(*f)(xgh[i]);
636 
637  return q;
638 }
639 
640 float64_t CIntegration::evaluate_quadgh64(CFunction* f)
641 {
642  static const index_t n=64;
643 
644  // Gauss-Hermite nodes
645  static float64_t xgh[n]=
646  {
647  -10.52612316796054588332682628381528,
648  -9.895287586829539021204461477159608,
649  -9.373159549646721162545652439723862,
650  -8.907249099964769757295972885642943,
651  -8.477529083379863090564166344821916,
652  -8.073687285010225225858791140758144,
653  -7.68954016404049682844780422986949,
654  -7.321013032780949201189569363719477,
655  -6.965241120551107529242642193492688,
656  -6.620112262636027379036660108937914,
657  -6.284011228774828235418093195070243,
658  -5.955666326799486045344567180984366,
659  -5.634052164349972147249920483307154,
660  -5.318325224633270857323649515199378,
661  -5.007779602198768196443702627184136,
662  -4.701815647407499816097538015812822,
663  -4.399917168228137647767932535438923,
664  -4.101634474566656714970981238455522,
665  -3.806571513945360461165972000460225,
666  -3.514375935740906211539950586474333,
667  -3.224731291992035725848171110188419,
668  -2.93735082300462180968533902619139,
669  -2.651972435430635011005457785998431,
670  -2.368354588632401404111511265341516,
671  -2.086272879881762020832563302363221,
672  -1.805517171465544918903773574186889,
673  -1.525889140209863662948970133151528,
674  -1.24720015694311794069356453069359,
675  -0.9692694230711780167435414890191023,
676  -0.6919223058100445772682192875955947,
677  -0.4149888241210786845769291291996859,
678  -0.1383022449870097241150497679666744,
679  0.1383022449870097241150497679666744,
680  0.4149888241210786845769291291996859,
681  0.6919223058100445772682192875955947,
682  0.9692694230711780167435414890191023,
683  1.24720015694311794069356453069359,
684  1.525889140209863662948970133151528,
685  1.805517171465544918903773574186889,
686  2.086272879881762020832563302363221,
687  2.368354588632401404111511265341516,
688  2.651972435430635011005457785998431,
689  2.93735082300462180968533902619139,
690  3.224731291992035725848171110188419,
691  3.514375935740906211539950586474333,
692  3.806571513945360461165972000460225,
693  4.101634474566656714970981238455522,
694  4.399917168228137647767932535438923,
695  4.701815647407499816097538015812822,
696  5.007779602198768196443702627184136,
697  5.318325224633270857323649515199378,
698  5.634052164349972147249920483307154,
699  5.955666326799486045344567180984366,
700  6.284011228774828235418093195070243,
701  6.620112262636027379036660108937914,
702  6.965241120551107529242642193492688,
703  7.321013032780949201189569363719477,
704  7.68954016404049682844780422986949,
705  8.073687285010225225858791140758144,
706  8.477529083379863090564166344821916,
707  8.907249099964769757295972885642943,
708  9.373159549646721162545652439723862,
709  9.895287586829539021204461477159608,
710  10.52612316796054588332682628381528
711  };
712 
713  // Gauss-Hermite weights
714  static float64_t wgh[n]=
715  {
716  5.535706535856942820575463300987E-49,
717  1.6797479901081592186662883306299E-43,
718  3.4211380112557405043272218281457E-39,
719  1.557390624629763802309335380265E-35,
720  2.549660899112999256604766580441E-32,
721  1.92910359546496685030196877906707E-29,
722  7.8617977889259103690999914962788E-27,
723  1.911706883300642829958456965534449E-24,
724  2.982862784279851154478700702016E-22,
725  3.15225456650378141612134668341E-20,
726  2.35188471067581911695767591555844E-18,
727  1.28009339132243804163956329526337E-16,
728  5.218623726590847522957808513052588E-15,
729  1.628340730709720362084307081240893E-13,
730  3.95917776694772392723644586425458E-12,
731  7.61521725014545135331529567531937E-11,
732  1.1736167423215493435425064670822E-9,
733  1.465125316476109354926622003804004E-8,
734  1.495532936727247061102461692934817E-7,
735  1.258340251031184576157842180019028E-6,
736  8.7884992308503591814440474067043E-6,
737  5.125929135786274660821911412739621E-5,
738  2.509836985130624860823620179819094E-4,
739  0.001036329099507577663456741746283101,
740  0.00362258697853445876066812537162265,
741  0.01075604050987913704946517278667313,
742  0.0272031289536889184538348212614932,
743  0.0587399819640994345496889462518317,
744  0.1084983493061868406330258455060973,
745  0.1716858423490837020007279701237768,
746  0.2329947860626780466505660293325675,
747  0.2713774249413039779456065084184279,
748  0.2713774249413039779456065084184279,
749  0.2329947860626780466505660293325675,
750  0.1716858423490837020007279701237768,
751  0.1084983493061868406330258455060973,
752  0.0587399819640994345496889462518317,
753  0.0272031289536889184538348212614932,
754  0.01075604050987913704946517278667313,
755  0.00362258697853445876066812537162265,
756  0.001036329099507577663456741746283101,
757  2.509836985130624860823620179819094E-4,
758  5.125929135786274660821911412739621E-5,
759  8.7884992308503591814440474067043E-6,
760  1.258340251031184576157842180019028E-6,
761  1.495532936727247061102461692934817E-7,
762  1.465125316476109354926622003804004E-8,
763  1.1736167423215493435425064670822E-9,
764  7.61521725014545135331529567531937E-11,
765  3.95917776694772392723644586425458E-12,
766  1.628340730709720362084307081240893E-13,
767  5.218623726590847522957808513052588E-15,
768  1.28009339132243804163956329526337E-16,
769  2.35188471067581911695767591555844E-18,
770  3.15225456650378141612134668341E-20,
771  2.982862784279851154478700702016E-22,
772  1.911706883300642829958456965534449E-24,
773  7.8617977889259103690999914962788E-27,
774  1.92910359546496685030196877906707E-29,
775  2.549660899112999256604766580441E-32,
776  1.557390624629763802309335380265E-35,
777  3.4211380112557405043272218281457E-39,
778  1.6797479901081592186662883306299E-43,
779  5.535706535856942820575463300987E-49
780  };
781 
782  return evaluate_quadgh(f, n, xgh, wgh);
783 }
784 }
785 
786 #endif /* HAVE_EIGEN3 */

SHOGUN Machine Learning Toolbox - Documentation