SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RelaxedTree.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Chiyuan Zhang
8  * Copyright (C) 2012 Chiyuan Zhang
9  */
10 
11 #ifndef RELAXEDTREE_H__
12 #define RELAXEDTREE_H__
13 
14 #include <utility>
15 #include <vector>
16 
17 #include <shogun/lib/config.h>
18 
23 
24 namespace shogun
25 {
26 
/* Forward declaration: lets this header name CBaseMulticlassMachine without
 * including its definition. NOTE(review): no visible use survives in this
 * listing; presumably it is the parameter type of the setter whose
 * signature (original line 87) is missing below -- confirm. */
27 class CBaseMulticlassMachine;
28 
/** Multiclass tree machine built on CTreeMachine, with RelaxedTreeNodeData
 * stored in each node.
 *
 * NOTE(review): this text is a Doxygen source listing, not the raw header.
 * Each line is prefixed with its original line number, and the gaps in that
 * numbering are lines lost in extraction. Those lost lines include the
 * original doc comments, several method signatures (before listing lines 55,
 * 88, 98, 105, 113 and 120), some body lines (57, 67, 80-81, 90-91, 114) and
 * most member declarations (226-240). Regenerate from the real RelaxedTree.h
 * before compiling.
 */
36 class CRelaxedTree: public CTreeMachine<RelaxedTreeNodeData>
37 {
38 public:
    /** Default constructor (declared here, defined out of line). */
40  CRelaxedTree();
41 
    /** Virtual destructor (declared here, defined out of line). */
43  virtual ~CRelaxedTree();
44 
    /** @return the class name, "RelaxedTree" */
46  virtual const char* get_name() const { return "RelaxedTree"; }
47 
    /** Apply the trained tree to features and return multiclass labels.
     * NOTE(review): handling of the default data=NULL is not visible here;
     * presumably it falls back to previously set features -- confirm.
     */
49  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
50 
    /* NOTE(review): the signature of this setter (original line 54) is
     * missing from the listing. The body takes a reference on `feats` via
     * SG_REF and stores it in m_feats. Original line 57 is also missing;
     * presumably it releases the previous m_feats (SG_UNREF) -- confirm,
     * otherwise re-setting features leaks a reference.
     */
55  {
56  SG_REF(feats);
58  m_feats = feats;
59  }
60 
    /** Store `kernel` in m_kernel, taking a reference on it with SG_REF.
     * @param kernel kernel to use
     * NOTE(review): original line 67 is missing; presumably it releases the
     * previously stored m_kernel -- confirm.
     */
64  virtual void set_kernel(CKernel *kernel)
65  {
66  SG_REF(kernel);
68  m_kernel = kernel;
69  }
70 
    /** Set the training labels, which are expected to be CMulticlassLabels.
     * @param lab labels
     * NOTE(review): the REQUIRE checks `lab`, not the dynamic_cast result
     * `mlab`. A non-NULL CLabels of another type therefore passes the check
     * and leaves mlab == NULL. REQUIRE(mlab, ...) looks intended. The message
     * also misspells "MulticlassLabels". Original lines 80-81 (the rest of
     * this body) are missing from the listing.
     */
75  virtual void set_labels(CLabels* lab)
76  {
77  CMulticlassLabels *mlab = dynamic_cast<CMulticlassLabels *>(lab);
78  REQUIRE(lab, "requires MulticlassLabes\n")
79 
82  }
83 
    /* NOTE(review): the signature (original line 87) and body lines 90-91
     * are missing. The visible body only takes a reference on `machine`
     * via SG_REF.
     */
88  {
89  SG_REF(machine);
92  }
93 
    /* Setter body: stores C in m_svm_C (signature, original line 97, missing). */
98  {
99  m_svm_C = C;
100  }
    /* Getter body: returns m_svm_C (signature, original line 104, missing). */
105  {
106  return m_svm_C;
107  }
108 
    /* NOTE(review): signature (original line 112) and the only body line
     * (original line 114) are missing. Given the getter below, this
     * presumably assigns m_svm_epsilon -- confirm.
     */
113  {
115  }
    /* Getter body: returns m_svm_epsilon (signature, original line 119, missing). */
120  {
121  return m_svm_epsilon;
122  }
123 
    /** Set parameter A (stored in m_A). */
129  void set_A(float64_t A)
130  {
131  m_A = A;
132  }
    /** @return parameter A (m_A) */
136  float64_t get_A() const
137  {
138  return m_A;
139  }
140 
    /** Set parameter B (stored in m_B). */
145  void set_B(int32_t B)
146  {
147  m_B = B;
148  }
    /** @return parameter B (m_B) */
152  int32_t get_B() const
153  {
154  return m_B;
155  }
156 
    /** Set m_max_num_iter; presumably the iteration cap used during training -- confirm. */
160  void set_max_num_iter(int32_t n_iter)
161  {
162  m_max_num_iter = n_iter;
163  }
    /** @return m_max_num_iter */
167  int32_t get_max_num_iter() const
168  {
169  return m_max_num_iter;
170  }
171 
    /** Train the machine.
     * Explicitly calls CMachine::train(data). The qualified call skips any
     * train() override in intermediate base classes.
     * @param data training features (may be NULL)
     * @return result of CMachine::train
     */
181  virtual bool train(CFeatures* data=NULL)
182  {
183  return CMachine::train(data);
184  }
185 
    /** Pair of int32_t values paired with a float64_t; init_node() returns a
     * vector of these. NOTE(review): presumably (class, class) -> score --
     * confirm in the .cpp.
     */
187  typedef std::pair<std::pair<int32_t, int32_t>, float64_t> entry_t;
188 protected:
    /** Predict for the single example at index `idx`.
     * NOTE(review): the meaning of the float64_t result (presumably a class
     * label) is not visible here -- confirm.
     */
195  float64_t apply_one(int32_t idx);
196 
    /** Training implementation (defined out of line). */
203  virtual bool train_machine(CFeatures* data);
204 
    /** Build a tree node for the class subset `classes` from `conf_mat`.
     * Ownership of the returned bnode_t* is not visible here -- confirm.
     */
206  bnode_t *train_node(const SGMatrix<float64_t> &conf_mat, SGVector<int32_t> classes);
    /** Compute candidate entry_t values for `classes` from `global_conf_mat`. */
208  std::vector<entry_t> init_node(const SGMatrix<float64_t> &global_conf_mat, SGVector<int32_t> classes);
211 
218 
    /** Adjust `mu` (passed by non-const reference) against the bound
     * B_prime -- upper variant. Details are in the .cpp.
     */
220  void enforce_balance_constraints_upper(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);
    /** Lower-bound counterpart of enforce_balance_constraints_upper. */
222  void enforce_balance_constraints_lower(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);
223 
    /* Backing field of set/get_max_num_iter. */
225  int32_t m_max_num_iter;
    /* Backing field of set/get_B. NOTE(review): declarations of m_A, m_svm_C,
     * m_svm_epsilon, m_feats and m_kernel used above are in original lines
     * missing from this listing. */
229  int32_t m_B;
    /* Number of classes; presumably assigned in set_labels (that line is
     * missing from the listing) -- confirm. */
241  int32_t m_num_classes;
242 };
243 
244 } /* shogun */
245 
246 #endif /* end of include guard: RELAXEDTREE_H__ */
247 

SHOGUN Machine Learning Toolbox - Documentation