SHOGUN  6.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules
MPDSVM.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
12 #include <shogun/io/SGIO.h>
13 #include <shogun/lib/common.h>
15 #include <shogun/lib/Signal.h>
16 
17 using namespace shogun;
18 
20 : CSVM()
21 {
22 }
23 
25 : CSVM(C, k, lab)
26 {
27 }
28 
30 {
31 }
32 
34 {
37  ASSERT(kernel)
38 
39  if (data)
40  {
41  if (m_labels->get_num_labels() != data->get_num_vectors())
42  SG_ERROR("Number of training vectors does not match number of labels\n")
43  kernel->init(data, data);
44  }
46 
47  //const float64_t nu=0.32;
48  const float64_t alpha_eps=1e-12;
49  const float64_t eps=get_epsilon();
50  const int64_t maxiter = 1L<<30;
51  //const bool nustop=false;
52  //const int32_t k=2;
53  const int32_t n=m_labels->get_num_labels();
54  ASSERT(n>0)
55  //const float64_t d = 1.0/n/nu; //NUSVC
56  const float64_t d = get_C1(); //CSVC
57  const float64_t primaleps=eps;
58  const float64_t dualeps=eps*n; //heuristic
59  int64_t niter=0;
60 
62  float64_t* alphas=SG_MALLOC(float64_t, n);
63  float64_t* dalphas=SG_MALLOC(float64_t, n);
64  //float64_t* hessres=SG_MALLOC(float64_t, 2*n);
65  float64_t* hessres=SG_MALLOC(float64_t, n);
66  //float64_t* F=SG_MALLOC(float64_t, 2*n);
67  float64_t* F=SG_MALLOC(float64_t, n);
68 
69  //float64_t hessest[2]={0,0};
70  //float64_t hstep[2];
71  //float64_t etas[2]={0,0};
72  //float64_t detas[2]={0,1}; //NUSVC
73  float64_t etas=0;
74  float64_t detas=0; //CSVC
75  float64_t hessest=0;
76  float64_t hstep;
77 
78  const float64_t stopfac = 1;
79 
80  bool primalcool;
81  bool dualcool;
83 
84  //if (nustop)
85  //etas[1] = 1;
86 
87  for (int32_t i=0; i<n; i++)
88  {
89  alphas[i]=0;
90  F[i]=((CBinaryLabels*) m_labels)->get_label(i);
91  //F[i+n]=-1;
92  hessres[i]=((CBinaryLabels*) m_labels)->get_label(i);
93  //hessres[i+n]=-1;
94  //dalphas[i]=F[i+n]*etas[1]; //NUSVC
95  dalphas[i]=-1; //CSVC
96  }
97 
98  // go ...
99  while (niter++ < maxiter && !CSignal::cancel_computations())
100  {
101  int32_t maxpidx=-1;
102  float64_t maxpviol = -1;
103  //float64_t maxdviol = CMath::abs(detas[0]);
104  float64_t maxdviol = CMath::abs(detas);
105  bool free_alpha=false;
106 
107  //if (CMath::abs(detas[1])> maxdviol)
108  //maxdviol=CMath::abs(detas[1]);
109 
110  // compute kkt violations with correct sign ...
111  for (int32_t i=0; i<n; i++)
112  {
113  float64_t v=CMath::abs(dalphas[i]);
114 
115  if (alphas[i] > 0 && alphas[i] < d)
116  free_alpha=true;
117 
118  if ( (dalphas[i]==0) ||
119  (alphas[i]==0 && dalphas[i] >0) ||
120  (alphas[i]==d && dalphas[i] <0)
121  )
122  v=0;
123 
124  if (v > maxpviol)
125  {
126  maxpviol=v;
127  maxpidx=i;
128  } // if we cannot improve on maxpviol, we can still improve by choosing a cached element
129  else if (v == maxpviol)
130  {
131  if (kernel_cache->is_cached(i))
132  maxpidx=i;
133  }
134  }
135 
136  if (maxpidx<0 || maxdviol<0)
137  SG_ERROR("no violation no convergence, should not happen!\n")
138 
139  // ... and evaluate stopping conditions
140  //if (nustop)
141  //stopfac = CMath::max(etas[1], 1e-10);
142  //else
143  //stopfac = 1;
144 
145  if (niter%10000 == 0)
146  {
147  float64_t obj=0;
148 
149  for (int32_t i=0; i<n; i++)
150  {
151  obj-=alphas[i];
152  for (int32_t j=0; j<n; j++)
153  obj+=0.5*((CBinaryLabels*) m_labels)->get_label(i)*((CBinaryLabels*) m_labels)->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j);
154  }
155 
156  SG_DEBUG("obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter)
157  }
158 
159  //for (int32_t i=0; i<n; i++)
160  // SG_DEBUG("alphas:%f dalphas:%f\n", alphas[i], dalphas[i])
161 
162  primalcool = (maxpviol < primaleps*stopfac);
163  dualcool = (maxdviol < dualeps*stopfac) || (!free_alpha);
164 
165  // done?
166  if (primalcool && dualcool)
167  {
168  if (!free_alpha)
169  SG_INFO(" no free alpha, stopping! #iter=%d\n", niter)
170  else
171  SG_INFO(" done! #iter=%d\n", niter)
172  break;
173  }
174 
175 
176  ASSERT(maxpidx>=0 && maxpidx<n)
177  // hessian updates
178  hstep=-hessres[maxpidx]/compute_H(maxpidx,maxpidx);
179  //hstep[0]=-hessres[maxpidx]/(compute_H(maxpidx,maxpidx)+hessreg);
180  //hstep[1]=-hessres[maxpidx+n]/(compute_H(maxpidx,maxpidx)+hessreg);
181 
182  hessest-=F[maxpidx]*hstep;
183  //hessest[0]-=F[maxpidx]*hstep[0];
184  //hessest[1]-=F[maxpidx+n]*hstep[1];
185 
186  // do primal updates ..
187  float64_t tmpalpha = alphas[maxpidx] - dalphas[maxpidx]/compute_H(maxpidx,maxpidx);
188 
189  if (tmpalpha > d-alpha_eps)
190  tmpalpha = d;
191 
192  if (tmpalpha < 0+alpha_eps)
193  tmpalpha = 0;
194 
195  // update alphas & dalphas & detas ...
196  float64_t alphachange = tmpalpha - alphas[maxpidx];
197  alphas[maxpidx] = tmpalpha;
198 
199  KERNELCACHE_ELEM* h=lock_kernel_row(maxpidx);
200  for (int32_t i=0; i<n; i++)
201  {
202  hessres[i]+=h[i]*hstep;
203  //hessres[i]+=h[i]*hstep[0];
204  //hessres[i+n]+=h[i]*hstep[1];
205  dalphas[i] +=h[i]*alphachange;
206  }
207  unlock_kernel_row(maxpidx);
208 
209  detas+=F[maxpidx]*alphachange;
210  //detas[0]+=F[maxpidx]*alphachange;
211  //detas[1]+=F[maxpidx+n]*alphachange;
212 
213  // if at primal minimum, do eta update ...
214  if (primalcool)
215  {
216  //float64_t etachange[2] = { detas[0]/hessest[0] , detas[1]/hessest[1] };
217  float64_t etachange = detas/hessest;
218 
219  etas+=etachange;
220  //etas[0]+=etachange[0];
221  //etas[1]+=etachange[1];
222 
223  // update dalphas
224  for (int32_t i=0; i<n; i++)
225  dalphas[i]+= F[i] * etachange;
226  //dalphas[i]+= F[i] * etachange[0] + F[i+n] * etachange[1];
227  }
228  }
229 
230  if (niter >= maxiter)
231  SG_WARNING("increase maxiter ... \n")
232 
233 
234  int32_t nsv=0;
235  for (int32_t i=0; i<n; i++)
236  {
237  if (alphas[i]>0)
238  nsv++;
239  }
240 
241 
242  create_new_model(nsv);
243  //set_bias(etas[0]/etas[1]);
244  set_bias(etas);
245 
246  int32_t j=0;
247  for (int32_t i=0; i<n; i++)
248  {
249  if (alphas[i]>0)
250  {
251  //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]);
252  set_alpha(j, alphas[i]*((CBinaryLabels*) m_labels)->get_label(i));
253  set_support_vector(j, i);
254  j++;
255  }
256  }
258  SG_INFO("obj = %.16f, rho = %.16f\n",get_objective(),get_bias())
259  SG_INFO("Number of SV: %ld\n", get_num_support_vectors())
260 
261  SG_FREE(alphas);
262  SG_FREE(dalphas);
263  SG_FREE(hessres);
264  SG_FREE(F);
265  delete kernel_cache;
266 
267  return true;
268 }
virtual bool init(CFeatures *lhs, CFeatures *rhs)
Definition: Kernel.cpp:96
#define SG_INFO(...)
Definition: SGIO.h:117
virtual ELabelType get_label_type() const =0
Binary labels (+1/-1).
Definition: LabelTypes.h:18
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
virtual int32_t get_num_labels() const =0
virtual int32_t get_num_vectors() const =0
CLabels * m_labels
Definition: Machine.h:365
#define SG_ERROR(...)
Definition: SGIO.h:128
float64_t kernel(int32_t idx_a, int32_t idx_b)
void unlock_kernel_row(int32_t i)
Definition: MPDSVM.h:102
float64_t compute_H(int32_t i, int32_t j)
Definition: MPDSVM.h:65
float64_t get_C1()
Definition: SVM.h:161
KERNELCACHE_ELEM * lock_kernel_row(int32_t i)
Definition: MPDSVM.h:76
float64_t compute_svm_dual_objective()
Definition: SVM.cpp:237
#define ASSERT(x)
Definition: SGIO.h:200
virtual bool train_machine(CFeatures *data=NULL)
Definition: MPDSVM.cpp:33
void set_bias(float64_t bias)
CCache< KERNELCACHE_ELEM > * kernel_cache
Definition: MPDSVM.h:108
static void clear_cancel()
Definition: Signal.cpp:126
double float64_t
Definition: common.h:60
bool set_alpha(int32_t idx, float64_t val)
bool set_support_vector(int32_t idx, int32_t val)
virtual ~CMPDSVM()
Definition: MPDSVM.cpp:29
static bool cancel_computations()
Definition: Signal.h:111
#define SG_DEBUG(...)
Definition: SGIO.h:106
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
The class Features is the base class of all feature objects.
Definition: Features.h:68
A generic Support Vector Machine Interface.
Definition: SVM.h:49
The Kernel base class.
Binary labels for binary classification.
Definition: BinaryLabels.h:37
int32_t get_cache_size()
bool is_cached(int64_t number)
Definition: Cache.h:130
#define SG_WARNING(...)
Definition: SGIO.h:127
virtual bool has_features()
float64_t get_epsilon()
Definition: SVM.h:149
float64_t get_objective()
Definition: SVM.h:218
float64_t KERNELCACHE_ELEM
Definition: kernel/Kernel.h:35
bool create_new_model(int32_t num)
static T abs(T a)
Definition: Math.h:175

SHOGUN Machine Learning Toolbox - Documentation