SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MPDSVM.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
12 #include <shogun/io/SGIO.h>
13 #include <shogun/lib/common.h>
15 
16 using namespace shogun;
17 
19 : CSVM()
20 {
// Constructor body; its signature line (listing line 18) is not part of this
// listing. All it does is invoke the CSVM() base constructor — nothing else
// is initialised in this body.
21 }
22 
24 : CSVM(C, k, lab)
25 {
// Constructor body; its signature line (listing line 23) is not part of this
// listing. It passes C, k and lab straight through to CSVM(C, k, lab) and
// does nothing else. Presumably C is the regularisation constant, k the
// kernel and lab the labels — confirm against the declaration.
26 }
27 
29 {
// Body whose signature line (listing line 28) is not part of this listing —
// presumably the destructor; confirm. It is empty, so it does not free
// `kernel_cache`; the long routine later in this file deletes that pointer
// itself right before returning.
30 }
31 
33 {
// Builds the model with a coordinate-wise solver. Each outer iteration
// updates a single alphas[i], keeping 0 <= alphas[i] <= d. A separate scalar
// `etas` is adjusted along the way and is finally written out as the bias.
// Every alpha > 0 becomes a support vector with coefficient
// alphas[i]*label(i). The quantity being driven down appears to be the one
// logged every 10000 iterations below:
//   0.5*sum_ij y_i y_j a_i a_j K(i,j) - sum_i a_i.
// Returns true.
//
// NOTE(review): this listing is incomplete. The signature and listing lines
// 32, 34-35, 44, 60 and 255 are missing from it. `kernel_cache` is used
// below and deleted at the end, but it is never created in the visible
// lines; presumably it is allocated on one of the missing lines (e.g. 60).
// Confirm against the real source before editing.
36  ASSERT(kernel)
37 
// Optional new training features: they must provide one vector per label,
// and the kernel is then initialised on (data, data).
38  if (data)
39  {
40  if (m_labels->get_num_labels() != data->get_num_vectors())
41  SG_ERROR("Number of training vectors does not match number of labels\n")
42  kernel->init(data, data);
43  }
45 
// Solver constants:
// - alpha_eps snaps alphas that come within it of a bound onto that bound.
// - maxiter (2^30) caps the number of outer iterations.
// - d is the upper box bound for every alpha (CSVC variant: get_C1()).
// - The dual tolerance dualeps grows linearly with n.
// Commented-out lines belong to a disabled NUSVC variant.
46  //const float64_t nu=0.32;
47  const float64_t alpha_eps=1e-12;
48  const float64_t eps=get_epsilon();
49  const int64_t maxiter = 1L<<30;
50  //const bool nustop=false;
51  //const int32_t k=2;
52  const int32_t n=m_labels->get_num_labels();
53  ASSERT(n>0)
54  //const float64_t d = 1.0/n/nu; //NUSVC
55  const float64_t d = get_C1(); //CSVC
56  const float64_t primaleps=eps;
57  const float64_t dualeps=eps*n; //heuristic
58  int64_t niter=0;
59 
// Per-sample working arrays of length n, released with SG_FREE at the end.
// The commented-out 2*n sizes, and the [2] arrays below, belong to the
// disabled NUSVC variant.
61  float64_t* alphas=SG_MALLOC(float64_t, n);
62  float64_t* dalphas=SG_MALLOC(float64_t, n);
63  //float64_t* hessres=SG_MALLOC(float64_t, 2*n);
64  float64_t* hessres=SG_MALLOC(float64_t, n);
65  //float64_t* F=SG_MALLOC(float64_t, 2*n);
66  float64_t* F=SG_MALLOC(float64_t, n);
67 
68  //float64_t hessest[2]={0,0};
69  //float64_t hstep[2];
70  //float64_t etas[2]={0,0};
71  //float64_t detas[2]={0,1}; //NUSVC
72  float64_t etas=0;
73  float64_t detas=0; //CSVC
74  float64_t hessest=0;
75  float64_t hstep;
76 
77  const float64_t stopfac = 1;
78 
79  bool primalcool;
80  bool dualcool;
81 
82  //if (nustop)
83  //etas[1] = 1;
84 
// Start at alphas = 0 with dalphas = -1. That matches dalphas being the
// gradient of the logged objective at alphas = 0. F and hessres both start
// as copies of the labels.
85  for (int32_t i=0; i<n; i++)
86  {
87  alphas[i]=0;
88  F[i]=((CBinaryLabels*) m_labels)->get_label(i);
89  //F[i+n]=-1;
90  hessres[i]=((CBinaryLabels*) m_labels)->get_label(i);
91  //hessres[i+n]=-1;
92  //dalphas[i]=F[i+n]*etas[1]; //NUSVC
93  dalphas[i]=-1; //CSVC
94  }
95 
// Main loop. Each pass does four things:
// 1. pick the worst violator,
// 2. test convergence,
// 3. step that one alpha,
// 4. update etas once the primal side has converged.
// niter is incremented in the loop test, so it runs 1..maxiter inside the body.
96  // go ...
97  while (niter++ < maxiter)
98  {
// detas is changed only by `detas += F[maxpidx]*alphachange` further down,
// and F holds the labels. It therefore equals sum_i label(i)*alphas[i], and
// maxdviol is its absolute value.
99  int32_t maxpidx=-1;
100  float64_t maxpviol = -1;
101  //float64_t maxdviol = CMath::abs(detas[0]);
102  float64_t maxdviol = CMath::abs(detas);
103  bool free_alpha=false;
104 
105  //if (CMath::abs(detas[1])> maxdviol)
106  //maxdviol=CMath::abs(detas[1]);
107 
// How the violator is chosen:
// - v starts as |dalphas[i]|.
// - v is set to 0 when the gradient is 0, or when it would push alphas[i]
//   out of [0, d] from a bound alphas[i] already sits on.
// - The largest v wins. On a tie, an index whose kernel row is already
//   cached is preferred.
// free_alpha records whether any alpha lies strictly inside (0, d).
108  // compute kkt violations with correct sign ...
109  for (int32_t i=0; i<n; i++)
110  {
111  float64_t v=CMath::abs(dalphas[i]);
112 
113  if (alphas[i] > 0 && alphas[i] < d)
114  free_alpha=true;
115 
116  if ( (dalphas[i]==0) ||
117  (alphas[i]==0 && dalphas[i] >0) ||
118  (alphas[i]==d && dalphas[i] <0)
119  )
120  v=0;
121 
122  if (v > maxpviol)
123  {
124  maxpviol=v;
125  maxpidx=i;
126  } // if we cannot improve on maxpviol, we can still improve by choosing a cached element
127  else if (v == maxpviol)
128  {
129  if (kernel_cache->is_cached(i))
130  maxpidx=i;
131  }
132  }
133 
// NOTE(review): for n > 0 this guard can only fire if every v is NaN.
// Otherwise v >= 0 > -1 (the initial maxpviol), so maxpidx is set at i == 0.
// maxdviol < 0 is never true for an absolute value (assuming CMath::abs
// returns |x|).
// If SG_ERROR leaves this function (e.g. by throwing — confirm), the four
// SG_MALLOC'd arrays above are never released.
134  if (maxpidx<0 || maxdviol<0)
135  SG_ERROR("no violation no convergence, should not happen!\n")
136 
137  // ... and evaluate stopping conditions
138  //if (nustop)
139  //stopfac = CMath::max(etas[1], 1e-10);
140  //else
141  //stopfac = 1;
142 
// Every 10000 iterations, log the objective
//   0.5*sum_ij y_i y_j a_i a_j K(i,j) - sum_i a_i.
// This costs n^2 kernel evaluations. It is computed even when debug output
// is not displayed, because only the final print is inside SG_DEBUG.
143  if (niter%10000 == 0)
144  {
145  float64_t obj=0;
146 
147  for (int32_t i=0; i<n; i++)
148  {
149  obj-=alphas[i];
150  for (int32_t j=0; j<n; j++)
151  obj+=0.5*((CBinaryLabels*) m_labels)->get_label(i)*((CBinaryLabels*) m_labels)->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j);
152  }
153 
// NOTE(review): niter is int64_t but is printed with %d here and in the two
// SG_INFO(... #iter=%d ...) calls below. Assuming these macros are
// printf-style, format and argument do not match. "%lld" with a
// (long long) cast, or PRId64 from <cinttypes>, would match.
154  SG_DEBUG("obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter)
155  }
156 
157  //for (int32_t i=0; i<n; i++)
158  // SG_DEBUG("alphas:%f dalphas:%f\n", alphas[i], dalphas[i])
159 
// The solver has converged when both of these hold:
// - the largest violation is below primaleps, and
// - either |detas| is below dualeps or no alpha is strictly inside (0, d).
// stopfac stays fixed at 1 because the nustop scaling above is commented out.
160  primalcool = (maxpviol < primaleps*stopfac);
161  dualcool = (maxdviol < dualeps*stopfac) || (!free_alpha);
162 
163  // done?
164  if (primalcool && dualcool)
165  {
166  if (!free_alpha)
167  SG_INFO(" no free alpha, stopping! #iter=%d\n", niter)
168  else
169  SG_INFO(" done! #iter=%d\n", niter)
170  break;
171  }
172 
173 
174  ASSERT(maxpidx>=0 && maxpidx<n)
// hstep/hessres/hessest form a parallel iteration driven by the same kernel
// rows. The accumulated hessest is later used to scale the etas update.
// NOTE(review): this line and the primal step below both divide by
// compute_H(maxpidx,maxpidx) without checking that it is non-zero. A zero
// diagonal entry would give inf/NaN.
175  // hessian updates
176  hstep=-hessres[maxpidx]/compute_H(maxpidx,maxpidx);
177  //hstep[0]=-hessres[maxpidx]/(compute_H(maxpidx,maxpidx)+hessreg);
178  //hstep[1]=-hessres[maxpidx+n]/(compute_H(maxpidx,maxpidx)+hessreg);
179 
180  hessest-=F[maxpidx]*hstep;
181  //hessest[0]-=F[maxpidx]*hstep[0];
182  //hessest[1]-=F[maxpidx+n]*hstep[1];
183 
184  // do primal updates ..
// Step alphas[maxpidx] by -dalphas/compute_H(p,p). This is a 1-D Newton
// step if compute_H(p,p) is the objective's diagonal second derivative.
// The result is then clipped into [0, d], and values within alpha_eps of a
// bound are snapped onto that bound.
185  float64_t tmpalpha = alphas[maxpidx] - dalphas[maxpidx]/compute_H(maxpidx,maxpidx);
186 
187  if (tmpalpha > d-alpha_eps)
188  tmpalpha = d;
189 
190  if (tmpalpha < 0+alpha_eps)
191  tmpalpha = 0;
192 
193  // update alphas & dalphas & detas ...
194  float64_t alphachange = tmpalpha - alphas[maxpidx];
195  alphas[maxpidx] = tmpalpha;
196 
// Propagate the change to every hessres[i] and dalphas[i]. The row used,
// h, comes from lock_kernel_row(maxpidx) and is released with
// unlock_kernel_row right after the loop.
197  KERNELCACHE_ELEM* h=lock_kernel_row(maxpidx);
198  for (int32_t i=0; i<n; i++)
199  {
200  hessres[i]+=h[i]*hstep;
201  //hessres[i]+=h[i]*hstep[0];
202  //hessres[i+n]+=h[i]*hstep[1];
203  dalphas[i] +=h[i]*alphachange;
204  }
205  unlock_kernel_row(maxpidx);
206 
207  detas+=F[maxpidx]*alphachange;
208  //detas[0]+=F[maxpidx]*alphachange;
209  //detas[1]+=F[maxpidx+n]*alphachange;
210 
211  // if at primal minimum, do eta update ...
212  if (primalcool)
213  {
// NOTE(review): hessest starts at 0 and only changes through the hessian
// update above. Nothing here guarantees it is non-zero. If it is 0,
// etachange becomes inf/NaN, and that then spreads to etas and to every
// dalphas[i].
214  //float64_t etachange[2] = { detas[0]/hessest[0] , detas[1]/hessest[1] };
215  float64_t etachange = detas/hessest;
216 
217  etas+=etachange;
218  //etas[0]+=etachange[0];
219  //etas[1]+=etachange[1];
220 
221  // update dalphas
222  for (int32_t i=0; i<n; i++)
223  dalphas[i]+= F[i] * etachange;
224  //dalphas[i]+= F[i] * etachange[0] + F[i+n] * etachange[1];
225  }
226  }
227 
// Hitting maxiter only produces a warning; the current alphas are still used.
// Because niter is incremented in the loop test, a convergence break on the
// last permitted iteration (niter == maxiter) also triggers this warning.
228  if (niter >= maxiter)
229  SG_WARNING("increase maxiter ... \n")
230 
231 
// Build the model:
// - count the alphas > 0 and size the model with create_new_model(nsv),
// - use etas as the bias,
// - store alphas[i]*label(i) for each such i, in index order.
232  int32_t nsv=0;
233  for (int32_t i=0; i<n; i++)
234  {
235  if (alphas[i]>0)
236  nsv++;
237  }
238 
239 
240  create_new_model(nsv);
241  //set_bias(etas[0]/etas[1]);
242  set_bias(etas);
243 
244  int32_t j=0;
245  for (int32_t i=0; i<n; i++)
246  {
247  if (alphas[i]>0)
248  {
249  //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]);
250  set_alpha(j, alphas[i]*((CBinaryLabels*) m_labels)->get_label(i));
251  set_support_vector(j, i);
252  j++;
253  }
254  }
256  SG_INFO("obj = %.16f, rho = %.16f\n",get_objective(),get_bias())
// NOTE(review): %ld is used for get_num_support_vectors(), whose return type
// is not visible here. If it returns a 32-bit int (nsv above is int32_t),
// the specifier does not match on LP64 platforms. Confirm.
257  SG_INFO("Number of SV: %ld\n", get_num_support_vectors())
258 
259  SG_FREE(alphas);
260  SG_FREE(dalphas);
261  SG_FREE(hessres);
262  SG_FREE(F);
// NOTE(review): kernel_cache is deleted but not reset to NULL, so the name
// dangles until something reassigns it. Confirm that nothing (e.g. the
// destructor or a later call) dereferences or deletes it before that
// happens.
263  delete kernel_cache;
264 
265  return true;
266 }

SHOGUN Machine Learning Toolbox - Documentation