SHOGUN  6.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules
LibSVM.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
12 #include <shogun/io/SGIO.h>
14 
15 using namespace shogun;
16 
// NOTE(review): the signature line of this constructor (listing line 17) is
// missing from this extraction; presumably it is the default constructor
// CLibSVM::CLibSVM() — confirm against LibSVM.h.
// Default-constructs the CSVM base, selects the LIBSVM_C_SVC solver and
// registers solver_type with the parameter framework.
18 : CSVM(), solver_type(LIBSVM_C_SVC)
19 {
20  register_params();
21 }
22 
23 CLibSVM::CLibSVM(LIBSVM_SOLVER_TYPE st)
24 : CSVM(), solver_type(st)
25 {
26  register_params();
27 }
28 
29 
30 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab, LIBSVM_SOLVER_TYPE st)
31 : CSVM(C, k, lab), solver_type(st)
32 {
33  register_params();
34 }
35 
// NOTE(review): the signature line (listing line 36) is missing from this
// extraction; the reference notes give it as `virtual ~CLibSVM()`, defined at
// LibSVM.cpp:36.
// The body is empty: nothing is released here.
37 {
38 }
39 
40 void CLibSVM::register_params()
41 {
42  SG_ADD((machine_int_t*) &solver_type, "libsvm_solver_type", "LibSVM Solver type", MS_NOT_AVAILABLE);
43 }
44 
// Train the binary SVM with the LibSVM solver bundled in Shogun.
// NOTE(review): the signature line (listing line 45) is missing from this
// extraction; the reference notes declare it as
// `virtual bool train_machine(CFeatures *data=NULL)`, defined at LibSVM.cpp:45.
// Listing lines 53-54, 62, 73 and 105 are also absent (flagged at each gap).
//
// @param data if non-null, its vector count must equal the number of labels
//             and the kernel is initialised on (data, data); if null the
//             kernel is used as already initialised (presumably — confirm).
// @return true when svm_train produced a model (objective, bias, support
//         vector indices and alphas are copied into this machine); false
//         when svm_train returned no model.
// Errors are reported through SG_ERROR for: label/vector count mismatch,
// linear-term length mismatch, unknown solver type, and parameters rejected
// by svm_check_parameter.
46 {
47  svm_problem problem;
48  svm_parameter param;
49  struct svm_model* model = nullptr;
50 
51  struct svm_node* x_space;
52 
// NOTE(review): listing lines 53-54 are missing here; given the reference
// notes (get_label_type, "binary labels +1/-1") they presumably check that
// m_labels is set and binary — confirm against the real source.
55 
56  if (data)
57  {
58  if (m_labels->get_num_labels() != data->get_num_vectors())
59  {
60  SG_ERROR("%s::train_machine(): Number of training vectors (%d) does"
61  " not match number of labels (%d)\n", get_name(),
// NOTE(review): listing line 62 is missing; it supplies the two %d arguments
// and closes the SG_ERROR call above.
63  }
64  kernel->init(data, data);
65  }
66 
67  problem.l=m_labels->get_num_labels();
68  SG_INFO("%d trainlabels\n", problem.l)
69 
70  // set linear term
71  if (m_linear_term.vlen>0)
72  {
// NOTE(review): listing line 73 is missing; presumably it is the guard
// `if (m_linear_term.vlen != problem.l)`. As printed here the SG_ERROR below
// would fire whenever a linear term is set.
74  SG_ERROR("Number of training vectors does not match length of linear term\n")
75 
76  // set with linear term from base class
// The returned array is released with SG_FREE(problem.pv) further down, so it
// is presumably a fresh copy owned by this function — confirm in SVM.cpp.
77  problem.pv = get_linear_term_array();
78  }
79  else
80  {
81  // fill with minus ones
82  problem.pv = SG_MALLOC(float64_t, problem.l);
83 
84  for (int i=0; i!=problem.l; i++)
85  problem.pv[i] = -1.0;
86  }
87 
88  problem.y=SG_MALLOC(float64_t, problem.l);
89  problem.x=SG_MALLOC(struct svm_node*, problem.l);
// NOTE(review): problem.C is allocated but never written in this function;
// confirm the solver does not read it for C_SVC / NU_SVC.
90  problem.C=SG_MALLOC(float64_t, problem.l);
91 
// Two svm_node entries per example: the example's own index, then a node with
// index -1 (LibSVM's end-of-vector marker). No feature values are stored, so
// the solver presumably evaluates param.kernel by example index — confirm.
92  x_space=SG_MALLOC(struct svm_node, 2*problem.l);
93 
94  for (int32_t i=0; i<problem.l; i++)
95  {
// C-style cast: assumes m_labels really is a CBinaryLabels (the check is
// presumably in the missing lines 53-54).
96  problem.y[i]=((CBinaryLabels*) m_labels)->get_label(i);
97  problem.x[i]=&x_space[2*i];
98  x_space[2*i].index=i;
99  x_space[2*i+1].index=-1;
100  }
101 
// Per-class weights: label -1 keeps 1.0, label +1 gets C2/C1. With
// param.C = C1 below this presumably yields an effective C of C2 for the
// positive class. NOTE(review): no guard against get_C1() == 0.
102  int32_t weights_label[2]={-1,+1};
103  float64_t weights[2]={1.0,get_C2()/get_C1()};
104 
// NOTE(review): listing line 105 is missing here (presumably a further kernel
// assertion — confirm).
106  ASSERT(kernel->get_num_vec_lhs()==problem.l)
107 
108  switch (solver_type)
109  {
110  case LIBSVM_C_SVC:
111  param.svm_type=C_SVC;
112  break;
113  case LIBSVM_NU_SVC:
114  param.svm_type=NU_SVC;
115  break;
116  default:
117  SG_ERROR("%s::train_machine(): Unknown solver type!\n", get_name());
118  break;
119  }
120 
// kernel_type/degree/gamma/coef0 are fixed placeholder values; the kernel
// actually used is the Shogun kernel passed through param.kernel.
121  param.kernel_type = LINEAR;
122  param.degree = 3;
123  param.gamma = 0; // 1/k
124  param.coef0 = 0;
125  param.nu = get_nu();
126  param.kernel=kernel;
127  param.cache_size = kernel->get_cache_size();
128  param.max_train_time = m_max_train_time;
129  param.C = get_C1();
130  param.eps = epsilon;
131  param.p = 0.1;
132  param.shrinking = 1;
133  param.nr_weight = 2;
134  param.weight_label = weights_label;
135  param.weight = weights;
136  param.use_bias = get_bias_enabled();
137 
138  const char* error_msg = svm_check_parameter(&problem, &param);
139 
// NOTE(review): if SG_ERROR throws, problem.x/y/pv/C and x_space allocated
// above are not freed.
140  if(error_msg)
141  SG_ERROR("Error: %s\n",error_msg)
142 
143  model = svm_train(&problem, &param);
144 
145  if (model)
146  {
147  ASSERT(model->nr_class==2)
148  ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]))
149 
150  int32_t num_sv=model->l;
151 
152  create_new_model(num_sv);
153  CSVM::set_objective(model->objective);
154 
// Bias and alphas are multiplied by the model's first class label; presumably
// this re-orients LibSVM's decision function to Shogun's sign convention.
155  float64_t sgn=model->label[0];
156 
157  set_bias(-sgn*model->rho[0]);
158 
// Each support vector is stored as the example index it was encoded with.
159  for (int32_t i=0; i<num_sv; i++)
160  {
161  set_support_vector(i, (model->SV[i])->index);
162  set_alpha(i, sgn*model->sv_coef[0][i]);
163  }
164 
165  SG_FREE(problem.x);
166  SG_FREE(problem.y);
167  SG_FREE(problem.pv);
168  SG_FREE(problem.C);
169 
170 
171  SG_FREE(x_space);
172 
173  svm_destroy_model(model);
174  model=NULL;
175  return true;
176  }
// NOTE(review): on this path problem.x/y/pv/C and x_space are never freed —
// memory leak when svm_train returns no model.
177  else
178  return false;
179 }
virtual bool init(CFeatures *lhs, CFeatures *rhs)
Definition: Kernel.cpp:96
#define SG_INFO(...)
Definition: SGIO.h:117
virtual ELabelType get_label_type() const =0
binary labels +1/-1
Definition: LabelTypes.h:18
float64_t get_C2()
Definition: SVM.h:167
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
virtual int32_t get_num_labels() const =0
virtual int32_t get_num_vectors() const =0
float64_t m_max_train_time
Definition: Machine.h:362
CLabels * m_labels
Definition: Machine.h:365
#define SG_ERROR(...)
Definition: SGIO.h:128
float64_t get_nu()
Definition: SVM.h:155
float64_t epsilon
Definition: SVM.h:266
virtual int32_t get_num_vec_lhs()
int32_t cache_size
cache_size in MB
LIBSVM_SOLVER_TYPE solver_type
Definition: LibSVM.h:79
float64_t get_C1()
Definition: SVM.h:161
virtual float64_t * get_linear_term_array()
Definition: SVM.cpp:297
SGVector< float64_t > m_linear_term
Definition: SVM.h:261
index_t vlen
Definition: SGVector.h:545
#define ASSERT(x)
Definition: SGIO.h:200
void set_bias(float64_t bias)
double float64_t
Definition: common.h:60
bool set_alpha(int32_t idx, float64_t val)
void set_objective(float64_t v)
Definition: SVM.h:209
virtual bool train_machine(CFeatures *data=NULL)
Definition: LibSVM.cpp:45
bool set_support_vector(int32_t idx, int32_t val)
virtual const char * get_name() const
Definition: LibSVM.h:61
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
int machine_int_t
Definition: common.h:69
The class Features is the base class of all feature objects.
Definition: Features.h:68
A generic Support Vector Machine Interface.
Definition: SVM.h:49
The Kernel base class.
Binary Labels for binary classification.
Definition: BinaryLabels.h:37
int32_t get_cache_size()
virtual ~CLibSVM()
Definition: LibSVM.cpp:36
#define SG_ADD(...)
Definition: SGObject.h:94
virtual bool has_features()
bool create_new_model(int32_t num)

SHOGUN Machine Learning Toolbox - Documentation