SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
SGDQN.h
Go to the documentation of this file.
1 #ifndef _SGDQN_H___
2 #define _SGDQN_H___
3 
4 /*
5  SVM with Quasi-Newton stochastic gradient
6  Copyright (C) 2009- Antoine Bordes
7 
8  This program is free software; you can redistribute it and/or
9  modify it under the terms of the GNU Lesser General Public
10  License as published by the Free Software Foundation; either
11  version 2.1 of the License, or (at your option) any later version.
12 
13  This program is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16  GNU General Public License for more details.
17 
18  You should have received a copy of the GNU Lesser General Public
19  License along with this library; if not, write to the Free Software
20  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21 
22  Shogun adjustments (w) 2011 Siddharth Kherada
23 */
24 
25 #include <shogun/lib/config.h>
26 
27 #include <shogun/lib/common.h>
30 #include <shogun/labels/Labels.h>
32 
33 namespace shogun
34 {
36 class CSGDQN : public CLinearMachine
37 {
38  public:
39 
42 
44  CSGDQN();
45 
50  CSGDQN(float64_t C);
51 
58  CSGDQN(
59  float64_t C, CDotFeatures* traindat,
60  CLabels* trainlab);
61 
62  virtual ~CSGDQN();
63 
69 
78  virtual bool train(CFeatures* data=NULL);
79 
86  inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; }
87 
92  inline float64_t get_C1() { return C1; }
93 
98  inline float64_t get_C2() { return C2; }
99 
104  inline void set_epochs(int32_t e) { epochs=e; }
105 
110  inline int32_t get_epochs() { return epochs; }
111 
113  void compute_ratio(float64_t* W,float64_t* W_1,float64_t* B,float64_t* dst,int32_t dim,float64_t regularizer_lambda,float64_t loss);
114 
116  void combine_and_clip(float64_t* Bc,float64_t* B,int32_t dim,float64_t c1,float64_t c2,float64_t v1,float64_t v2);
117 
122  void set_loss_function(CLossFunction* loss_func);
123 
128  inline CLossFunction* get_loss_function() { SG_REF(loss); return loss; }
129 
131  virtual const char* get_name() const { return "SGDQN"; }
132 
133  protected:
135  void calibrate();
136 
137  private:
138  void init();
139 
140  private:
141  float64_t t;
142  float64_t C1;
143  float64_t C2;
144  int32_t epochs;
145  int32_t skip;
146  int32_t count;
147 
148  CLossFunction* loss;
149 };
150 }
151 #endif

SHOGUN Machine Learning Toolbox - Documentation