SHOGUN  6.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules
LinearTimeMMD.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2012 - 2013 Heiko Strathmann
4  * Written (w) 2014 - 2017 Soumyajit De
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright notice, this
11  * list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright notice,
13  * this list of conditions and the following disclaimer in the documentation
14  * and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  * The views and conclusions contained in the software and documentation are those
28  * of the authors and should not be interpreted as representing official policies,
29  * either expressed or implied, of the Shogun Development Team.
30  */
31 
32 #include <shogun/io/SGIO.h>
33 #include <shogun/lib/SGMatrix.h>
41 
42 using namespace shogun;
43 using namespace internal;
44 
46 {
47 }
48 
49 CLinearTimeMMD::CLinearTimeMMD(CFeatures* samples_from_p, CFeatures* samples_from_q) : CStreamingMMD()
50 {
51  set_p(samples_from_p);
52  set_q(samples_from_q);
53 }
54 
56 {
57 }
58 
60 {
61  auto& data_mgr=get_data_mgr();
62  auto min_blocksize=data_mgr.get_min_blocksize();
63  if (min_blocksize==2)
64  {
65  // only possible when number of samples from both the distributions are the same
66  auto N=data_mgr.num_samples_at(0);
67  for (auto i=2; i<N; ++i)
68  {
69  if (N%i==0)
70  {
71  min_blocksize=i*2;
72  break;
73  }
74  }
75  }
76  data_mgr.set_blocksize(min_blocksize);
77  data_mgr.set_num_blocks_per_burst(num_blocks_per_burst);
78  SG_SDEBUG("Block contains %d and %d samples, from P and Q respectively!\n", data_mgr.blocksize_at(0), data_mgr.blocksize_at(1));
79 }
80 
81 const std::function<float32_t(SGMatrix<float32_t>)> CLinearTimeMMD::get_direct_estimation_method() const
82 {
83  return mmd::WithinBlockDirect();
84 }
85 
86 float64_t CLinearTimeMMD::normalize_statistic(float64_t statistic) const
87 {
88  const DataManager& data_mgr = get_data_mgr();
89  const index_t Nx = data_mgr.num_samples_at(0);
90  const index_t Ny = data_mgr.num_samples_at(1);
91  return CMath::sqrt(Nx * Ny / float64_t(Nx + Ny)) * statistic;
92 }
93 
/**
 * Rescales a variance estimate using the per-block sample counts
 * Bx = blocksize_at(0) (from P) and By = blocksize_at(1) (from Q),
 * with B = Bx + By.
 *
 * @param variance variance estimate to rescale
 * @return variance * B * (B - 2) / 16 in the special-cased branch, otherwise
 *         variance * Bx * By * (Bx - 1) * (By - 1) / ((B - 1) * (B - 2))
 */
94 const float64_t CLinearTimeMMD::normalize_variance(float64_t variance) const
95 {
96  const DataManager& data_mgr = get_data_mgr();
97  const index_t Bx = data_mgr.blocksize_at(0);
98  const index_t By = data_mgr.blocksize_at(1);
99  const index_t B = Bx + By;
101  {
	// Both return expressions start from the float64_t `variance`, so every
	// product/quotient below is evaluated in floating point (no index_t overflow).
102  return variance * B * (B - 2) / 16;
103  }
	// NOTE(review): with B == 2 (Bx == By == 1) the divisor (B - 2) is zero and the
	// result is NaN; presumably excluded by the minimum block size — confirm.
104  return variance * Bx * By * (Bx - 1) * (By - 1) / (B - 1) / (B - 2);
105 }
106 
/**
 * Rescales a variance estimate for the Gaussian (NAM_MMD1_GAUSSIAN) null
 * approximation. The p-value and threshold computations in this file take
 * sqrt(gaussian_variance(compute_variance())) as the null's standard deviation.
 * Bx = blocksize_at(0) (from P), By = blocksize_at(1) (from Q), B = Bx + By.
 *
 * @param variance variance estimate to rescale
 * @return variance * 4 / (B - 2) in the special-cased branch, otherwise
 *         variance * (B - 1) * (B - 2) / ((Bx - 1) * (By - 1) * B)
 */
107 const float64_t CLinearTimeMMD::gaussian_variance(float64_t variance) const
108 {
109  const DataManager& data_mgr = get_data_mgr();
110  const index_t Bx = data_mgr.blocksize_at(0);
111  const index_t By = data_mgr.blocksize_at(1);
112  const index_t B = Bx + By;
114  {
	// NOTE(review): B == 2 divides by zero here — confirm blocks are always larger.
115  return variance * 4 / (B - 2);
116  }
	// Evaluated in floating point because the chain starts from `variance`.
	// NOTE(review): Bx == 1 or By == 1 divides by zero — confirm this cannot occur.
117  return variance * (B - 1) * (B - 2) / (Bx - 1) / (By - 1) / B;
118 }
119 
// p-value for `statistic`, selected by the null-approximation method (NAM_* cases).
// Methods other than those handled here defer to CHypothesisTest::compute_p_value.
121 {
122  float64_t result = 0;
124  {
125  case NAM_MMD1_GAUSSIAN:
126  {
	// Upper-tail probability of `statistic` under a Gaussian null whose
	// standard deviation is derived from the variance estimate.
127  float64_t sigma_sq = gaussian_variance(compute_variance());
128  float64_t std_dev = CMath::sqrt(sigma_sq);
129  result = 1.0 - CStatistics::normal_cdf(statistic, std_dev);
130  break;
131  }
132  case NAM_PERMUTATION:
133  {
	// Permutation nulls are rejected for the linear-time statistic. If
	// SG_SERROR returns (presumably it throws — confirm), result stays 0.
134  SG_SERROR("Null approximation via permutation does not make sense "
135  "for linear time MMD. Use the Gaussian approximation instead.\n");
136  break;
137  }
138  default:
139  {
140  result = CHypothesisTest::compute_p_value(statistic);
141  break;
142  }
143  }
144  return result;
145 }
146 
// Rejection threshold for significance level `alpha`, selected by the
// null-approximation method (NAM_* cases); other methods defer to
// CHypothesisTest::compute_threshold.
// NOTE(review): unlike compute_p_value, NAM_PERMUTATION is not rejected here;
// it silently falls through to the base-class threshold — confirm intended.
148 {
149  float64_t result = 0;
151  {
152  case NAM_MMD1_GAUSSIAN:
153  {
154  float64_t sigma_sq = gaussian_variance(compute_variance());
155  float64_t std_dev = CMath::sqrt(sigma_sq);
	// NOTE(review): compute_p_value uses p = 1 - normal_cdf(statistic, std_dev),
	// so p < alpha exactly when statistic > inverse_normal_cdf(1 - alpha, 0, std_dev).
	// That quantile is the consistent threshold; the leading `1.0 -` below makes
	// the two methods disagree — likely a bug, confirm and fix.
156  result = 1.0 - CStatistics::inverse_normal_cdf(1 - alpha, 0, std_dev);
157  break;
158  }
159  default:
160  {
161  result = CHypothesisTest::compute_threshold(alpha);
162  break;
163  }
164  }
165  return result;
166 }
167 
168 const char* CLinearTimeMMD::get_name() const
169 {
170  return "LinearTimeMMD";
171 }
const ENullApproximationMethod get_null_approximation_method() const
virtual const char * get_name() const
index_t & num_samples_at(index_t i)
int32_t index_t
Definition: common.h:72
virtual float64_t compute_p_value(float64_t statistic)
void set_num_blocks_per_burst(index_t num_blocks_per_burst)
virtual float64_t compute_variance()
const EStatisticType get_statistic_type() const
virtual void set_p(CFeatures *samples_from_p)
virtual float64_t compute_threshold(float64_t alpha)
virtual void set_q(CFeatures *samples_from_q)
virtual float64_t compute_threshold(float64_t alpha)
double float64_t
Definition: common.h:60
static float64_t inverse_normal_cdf(float64_t y0, float64_t mean=0, float64_t std_dev=1)
Definition: Statistics.cpp:347
internal::DataManager & get_data_mgr()
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
#define SG_SDEBUG(...)
Definition: SGIO.h:167
The class Features is the base class of all feature objects.
Definition: Features.h:68
#define SG_SERROR(...)
Definition: SGIO.h:178
static float64_t normal_cdf(float64_t x, float64_t std_dev=1)
Definition: Statistics.cpp:507
Class DataManager for fetching/streaming test data block-wise. It can handle data coming from multipl...
Definition: DataManager.h:63
static float32_t sqrt(float32_t x)
Definition: Math.h:454
virtual float64_t compute_p_value(float64_t statistic)
const index_t blocksize_at(index_t i) const

SHOGUN Machine Learning Toolbox - Documentation