SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
LaRank.h
Go to the documentation of this file.
1 // -*- C++ -*-
2 // Main functions of the LaRank algorithm for solving Multiclass SVM
3 // Copyright (C) 2008- Antoine Bordes
4 // Shogun specific adjustments (w) 2009 Soeren Sonnenburg
5 
6 // This library is free software; you can redistribute it and/or
7 // modify it under the terms of the GNU Lesser General Public
8 // License as published by the Free Software Foundation; either
9 // version 2.1 of the License, or (at your option) any later version.
10 //
11 // This program is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 // GNU General Public License for more details.
15 //
16 // You should have received a copy of the GNU Lesser General Public
17 // License along with this library; if not, write to the Free Software
18 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 //
20 /***********************************************************************
21  *
22  * LUSH Lisp Universal Shell
23  * Copyright (C) 2002 Leon Bottou, Yann Le Cun, AT&T Corp, NECI.
24  * Includes parts of TL3:
25  * Copyright (C) 1987-1999 Leon Bottou and Neuristique.
26  * Includes selected parts of SN3.2:
27  * Copyright (C) 1991-2001 AT&T Corp.
28  *
29  * This program is free software; you can redistribute it and/or modify
30  * it under the terms of the GNU General Public License as published by
31  * the Free Software Foundation; either version 2 of the License, or
32  * (at your option) any later version.
33  *
34  * This program is distributed in the hope that it will be useful,
35  * but WITHOUT ANY WARRANTY; without even the implied warranty of
36  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
37  * GNU General Public License for more details.
38  *
39  * You should have received a copy of the GNU General Public License
40  * along with this program; if not, write to the Free Software
41  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA
42  *
43  ***********************************************************************/
44 
45 /***********************************************************************
46  * $Id: kcache.h,v 1.8 2007/01/25 22:42:09 leonb Exp $
47  **********************************************************************/
48 
49 #ifndef LARANK_H
50 #define LARANK_H
51 
52 #include <ctime>
53 #include <vector>
54 #include <algorithm>
55 #include <sys/time.h>
56 #include <set>
57 #include <map>
58 #define STDEXT_NAMESPACE __gnu_cxx
59 #define std_hash_map std::map
60 #define std_hash_set std::set
61 
62 #include <shogun/lib/config.h>
63 
64 #include <shogun/io/SGIO.h>
65 #include <shogun/kernel/Kernel.h>
67 
68 namespace shogun
69 {
70 #ifndef DOXYGEN_SHOULD_SKIP_THIS
71  struct larank_kcache_s;
72  typedef struct larank_kcache_s larank_kcache_t;
73  struct larank_kcache_s
74  {
75  CKernel* func;
76  larank_kcache_t *prevbuddy;
77  larank_kcache_t *nextbuddy;
78  int64_t maxsize;
79  int64_t cursize;
80  int32_t l;
81  int32_t *i2r;
82  int32_t *r2i;
83  int32_t maxrowlen;
84  /* Rows */
85  int32_t *rsize;
86  float32_t *rdiag;
87  float32_t **rdata;
88  int32_t *rnext;
89  int32_t *rprev;
90  int32_t *qnext;
91  int32_t *qprev;
92  };
93 
94  /*
95  ** OUTPUT: one per class of the raining set, keep tracks of support
96  * vectors and their beta coefficients
97  */
98  class LaRankOutput
99  {
100  public:
101  LaRankOutput () : beta(NULL), g(NULL), kernel(NULL), l(0)
102  {
103  }
104  virtual ~LaRankOutput ()
105  {
106  destroy();
107  }
108 
109  // Initializing an output class (basically creating a kernel cache for it)
110  void initialize (CKernel* kfunc, int64_t cache);
111 
112  // Destroying an output class (basically destroying the kernel cache)
113  void destroy ();
114 
115  // !Important! Computing the score of a given input vector for the actual output
116  float64_t computeScore (int32_t x_id);
117 
118  // !Important! Computing the gradient of a given input vector for the actual output
119  float64_t computeGradient (int32_t xi_id, int32_t yi, int32_t ythis);
120 
121  // Updating the solution in the actual output
122  void update (int32_t x_id, float64_t lambda, float64_t gp);
123 
124  // Linking the cache of this output to the cache of an other "buddy" output
125  // so that if a requested value is not found in this cache, you can
126  // ask your buddy if it has it.
127  void set_kernel_buddy (larank_kcache_t * bud);
128 
129  // Removing useless support vectors (for which beta=0)
130  int32_t cleanup ();
131 
132  // --- Below are information or "get" functions --- //
133 
134  //
135  inline larank_kcache_t *getKernel () const
136  {
137  return kernel;
138  }
139  //
140  inline int32_t get_l () const
141  {
142  return l;
143  }
144 
145  //
146  float64_t getW2 ();
147 
148  //
149  float64_t getKii (int32_t x_id);
150 
151  //
152  float64_t getBeta (int32_t x_id);
153 
154  //
155  inline float32_t* getBetas () const
156  {
157  return beta;
158  }
159 
160  //
161  float64_t getGradient (int32_t x_id);
162 
163  //
164  bool isSupportVector (int32_t x_id) const;
165 
166  //
167  int32_t getSV (float32_t* &sv) const;
168 
169  private:
170  // the solution of LaRank relative to the actual class is stored in
171  // this parameters
172  float32_t* beta; // Beta coefficiens
173  float32_t* g; // Strored gradient derivatives
174  larank_kcache_t *kernel; // Cache for kernel values
175  int32_t l; // Number of support vectors
176  };
177 
/*
** LARANKPATTERN: to keep track of the support patterns
**
** A pattern is an (example index, label) pair. A negative example index
** marks a cleared slot.
*/
class LaRankPattern
{
public:
	/* Pattern for example x_index carrying the given label. */
	LaRankPattern (int32_t x_index, int32_t label)
		: x_id (x_index), y (label) {}

	/* Default pattern: example index 0, label 0.
	 * The label is initialized explicitly so that copying or reading a
	 * default-constructed pattern never touches an indeterminate value.
	 * Note that index 0 is non-negative, so exists() is true here. */
	LaRankPattern ()
		: x_id (0), y (0) {}

	/* True while the pattern has not been cleared (x_id >= 0). */
	bool exists () const
	{
		return x_id >= 0;
	}

	/* Marks the pattern as removed; the label is left untouched. */
	void clear ()
	{
		x_id = -1;
	}

	int32_t x_id; // example index, -1 once cleared
	int32_t y;    // label
};
202 
203  /*
204  ** LARANKPATTERNS: the collection of support patterns
205  */
206  class LaRankPatterns
207  {
208  public:
209  LaRankPatterns () {}
210  ~LaRankPatterns () {}
211 
212  void insert (const LaRankPattern & pattern)
213  {
214  if (!isPattern (pattern.x_id))
215  {
216  if (freeidx.size ())
217  {
218  std_hash_set < uint32_t >::iterator it = freeidx.begin ();
219  patterns[*it] = pattern;
220  x_id2rank[pattern.x_id] = *it;
221  freeidx.erase (it);
222  }
223  else
224  {
225  patterns.push_back (pattern);
226  x_id2rank[pattern.x_id] = patterns.size () - 1;
227  }
228  }
229  else
230  {
231  int32_t rank = getPatternRank (pattern.x_id);
232  patterns[rank] = pattern;
233  }
234  }
235 
236  void remove (uint32_t i)
237  {
238  x_id2rank[patterns[i].x_id] = 0;
239  patterns[i].clear ();
240  freeidx.insert (i);
241  }
242 
243  bool empty () const
244  {
245  return patterns.size () == freeidx.size ();
246  }
247 
248  uint32_t size () const
249  {
250  return patterns.size () - freeidx.size ();
251  }
252 
253  LaRankPattern & sample ()
254  {
255  ASSERT (!empty ())
256  while (true)
257  {
258  uint32_t r = CMath::random(uint32_t(0), uint32_t(patterns.size ()-1));
259  if (patterns[r].exists ())
260  return patterns[r];
261  }
262  return patterns[0];
263  }
264 
265  uint32_t getPatternRank (int32_t x_id)
266  {
267  return x_id2rank[x_id];
268  }
269 
270  bool isPattern (int32_t x_id)
271  {
272  return x_id2rank[x_id] != 0;
273  }
274 
275  LaRankPattern & getPattern (int32_t x_id)
276  {
277  uint32_t rank = x_id2rank[x_id];
278  return patterns[rank];
279  }
280 
281  uint32_t maxcount () const
282  {
283  return patterns.size ();
284  }
285 
286  LaRankPattern & operator [] (uint32_t i)
287  {
288  return patterns[i];
289  }
290 
291  const LaRankPattern & operator [] (uint32_t i) const
292  {
293  return patterns[i];
294  }
295 
296  private:
297  std_hash_set < uint32_t >freeidx;
298  std::vector < LaRankPattern > patterns;
299  std_hash_map < int32_t, uint32_t >x_id2rank;
300  };
301 
302 
303 #endif // DOXYGEN_SHOULD_SKIP_THIS
304 
305 
309  class CLaRank: public CMulticlassSVM
310  {
311  public:
312  CLaRank ();
313 
320  CLaRank(float64_t C, CKernel* k, CLabels* lab);
321 
322  virtual ~CLaRank ();
323 
324  // LEARNING FUNCTION: add new patterns and run optimization steps
325  // selected with adaptative schedule
330  virtual int32_t add (int32_t x_id, int32_t yi);
331 
332  // PREDICTION FUNCTION: main function in la_rank_classify
336  virtual int32_t predict (int32_t x_id);
337 
339  virtual void destroy ();
340 
341  // Compute Duality gap (costly but used in stopping criteria in batch mode)
343  virtual float64_t computeGap ();
344 
345  // Nuber of classes so far
347  virtual uint32_t getNumOutputs () const;
348 
349  // Number of Support Vectors
351  int32_t getNSV ();
352 
353  // Norm of the parameters vector
355  float64_t computeW2 ();
356 
357  // Compute Dual objective value
359  float64_t getDual ();
360 
366 
368  virtual const char* get_name() const { return "LaRank"; }
369 
373  void set_batch_mode(bool enable) { batch_mode=enable; };
375  bool get_batch_mode() { return batch_mode; };
379  void set_tau(float64_t t) { tau=t; };
383  float64_t get_tau() { return tau; };
384 
385  protected:
387  bool train_machine(CFeatures* data);
388 
389  private:
390  /*
391  ** MAIN DARK OPTIMIZATION PROCESSES
392  */
393 
394  // Hash Table used to store the different outputs
396  typedef std_hash_map < int32_t, LaRankOutput > outputhash_t; // class index -> LaRankOutput
397 
399  outputhash_t outputs;
400 
401  LaRankOutput *getOutput (int32_t index);
402 
403  //
404  LaRankPatterns patterns;
405 
406  // Parameters
407  int32_t nb_seen_examples;
408  int32_t nb_removed;
409 
410  // Numbers of each operation performed so far
411  int32_t n_pro;
412  int32_t n_rep;
413  int32_t n_opt;
414 
415  // Running estimates for each operations
416  float64_t w_pro;
417  float64_t w_rep;
418  float64_t w_opt;
419 
420  int32_t y0;
421  float64_t m_dual;
422 
423  struct outputgradient_t
424  {
425  outputgradient_t (int32_t result_output, float64_t result_gradient)
426  : output (result_output), gradient (result_gradient) {}
427  outputgradient_t ()
428  : output (0), gradient (0) {}
429 
430  int32_t output;
431  float64_t gradient;
432 
433  bool operator < (const outputgradient_t & og) const
434  {
435  return gradient > og.gradient;
436  }
437  };
438 
439  //3 types of operations in LaRank
440  enum process_type
441  {
442  processNew,
443  processOld,
444  processOptimize
445  };
446 
447  struct process_return_t
448  {
449  process_return_t (float64_t dual, int32_t yprediction)
450  : dual_increase (dual), ypred (yprediction) {}
451  process_return_t () {}
452  float64_t dual_increase;
453  int32_t ypred;
454  };
455 
456  // IMPORTANT Main SMO optimization step
457  process_return_t process (const LaRankPattern & pattern, process_type ptype);
458 
459  // ProcessOld
460  float64_t reprocess ();
461 
462  // Optimize
463  float64_t optimize ();
464 
465  // remove patterns and return the number of patterns that were removed
466  uint32_t cleanup ();
467 
468  protected:
469 
471  std_hash_set < int32_t >classes;
472 
474  inline uint32_t class_count () const
475  {
476  return classes.size ();
477  }
478 
481 
483  int32_t nb_train;
485  int64_t cache;
488 
490  int32_t step;
491  };
492 }
493 #endif // LARANK_H

SHOGUN Machine Learning Toolbox - Documentation