SHOGUN  6.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules
KNN.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2006 Christian Gehl
8  * Written (W) 1999-2009 Soeren Sonnenburg
9  * Written (W) 2011 Sergey Lisitsyn
10  * Written (W) 2012 Fernando José Iglesias García, cover tree support
11  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
12  */
13 
14 #ifndef _KNN_H__
15 #define _KNN_H__
16 
17 #include <shogun/lib/config.h>
18 
19 #include <shogun/lib/common.h>
20 #include <shogun/io/SGIO.h>
30 
31 namespace shogun
32 {
34  {
/* NOTE(review): this extracted listing is missing source line 33 (the
 * `enum KNN_SOLVER` header -- the symbol index places KNN_SOLVER at KNN.h:33)
 * and lines 35-37, presumably the enumerators that precede KNN_LSH.
 * KNN_BRUTE is presumably one of them, since CKNN's constructor uses it as
 * the default solver. Restore these lines from upstream KNN.h before
 * compiling. KNN_LSH itself only exists when HAVE_CXX11 is defined. */
38 #ifdef HAVE_CXX11
39  KNN_LSH
40 #endif
41  };
42 
43 class CDistanceMachine;
44 
74 class CKNN : public CDistanceMachine
75 {
76  public:
78 
79 
80  CKNN();
81 
88  CKNN(int32_t k, CDistance* d, CLabels* trainlab, KNN_SOLVER knn_solver=KNN_BRUTE);
89 
90  virtual ~CKNN();
91 
96  virtual EMachineType get_classifier_type() { return CT_KNN; }
97 
108 
114  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
115 
117  virtual float64_t apply_one(int32_t vec_idx)
118  {
119  SG_ERROR("for performance reasons use apply() instead of apply(int32_t vec_idx)\n")
120  return 0;
121  }
122 
127 
133  virtual bool load(FILE* srcfile);
134 
140  virtual bool save(FILE* dstfile);
141 
146  inline void set_k(int32_t k)
147  {
148  ASSERT(k>0)
149  m_k=k;
150  }
151 
156  inline int32_t get_k()
157  {
158  return m_k;
159  }
160 
164  inline void set_q(float64_t q)
165  {
166  ASSERT(q<=1.0 && q>0.0)
167  m_q = q;
168  }
169 
173  inline float64_t get_q() { return m_q; }
174 
178  inline int32_t get_leaf_size() const {return m_leaf_size; }
179 
183  inline void set_leaf_size(int32_t leaf_size)
184  {
185  m_leaf_size = leaf_size;
186  }
187 
189  virtual const char* get_name() const { return "KNN"; }
190 
195  {
196  return m_knn_solver;
197  }
198 
203  inline void set_knn_solver_type(KNN_SOLVER knn_solver)
204  {
205  m_knn_solver = knn_solver;
206  }
207 
208 #ifdef HAVE_CXX11
209 
213  inline void set_lsh_parameters(int32_t l, int32_t t)
214  {
215  m_lsh_l = l;
216  m_lsh_t = t;
217  }
218 #endif
219 
220  protected:
225  virtual void store_model_features();
226 
230  virtual CMulticlassLabels* classify_NN();
231 
235  void init_distance(CFeatures* data);
236 
245  virtual bool train_machine(CFeatures* data=NULL);
246 
247  private:
248  void init();
249 
262  int32_t choose_class(float64_t* classes, int32_t* train_lab);
263 
276  void choose_class_for_multiple_k(int32_t* output, int32_t* classes, int32_t* train_lab, int32_t step);
277 
281  void init_solver(KNN_SOLVER knn_solver);
282 
283  protected:
285  int32_t m_k;
286 
289 
291  int32_t m_num_classes;
292 
294  int32_t m_min_label;
295 
298 
301 
303 
304  int32_t m_leaf_size;
305 
306 #ifdef HAVE_CXX11
307  /* Number of hash tables for LSH */
308  int32_t m_lsh_l;
309 
310  /* Number of probes per query for LSH */
311  int32_t m_lsh_t;
312 #endif
313 };
314 
315 }
316 #endif
EMachineType
Definition: Machine.h:33
virtual void store_model_features()
Definition: KNN.cpp:297
virtual bool save(FILE *dstfile)
Definition: KNN.cpp:290
virtual EMachineType get_classifier_type()
Definition: KNN.h:96
Class Distance, a base class for all the distances used in the Shogun toolbox.
Definition: Distance.h:87
void init_distance(CFeatures *data)
Definition: KNN.cpp:269
KNN_SOLVER m_knn_solver
Definition: KNN.h:302
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
float64_t get_q()
Definition: KNN.h:173
SGMatrix< int32_t > classify_for_multiple_k()
Definition: KNN.cpp:242
KNN_SOLVER
Definition: KNN.h:33
#define SG_ERROR(...)
Definition: SGIO.h:128
int32_t get_k()
Definition: KNN.h:156
int32_t m_min_label
smallest label, i.e. -1
Definition: KNN.h:294
virtual bool train_machine(CFeatures *data=NULL)
Definition: KNN.cpp:77
void set_q(float64_t q)
Definition: KNN.h:164
SGMatrix< index_t > nearest_neighbors()
Definition: KNN.cpp:114
A generic DistanceMachine interface.
virtual bool load(FILE *srcfile)
Definition: KNN.cpp:283
int32_t m_num_classes
number of classes (i.e. number of values labels can take)
Definition: KNN.h:291
Multiclass Labels for multi-class classification.
void set_leaf_size(int32_t leaf_size)
Definition: KNN.h:183
int32_t m_k
the k parameter in KNN
Definition: KNN.h:285
#define ASSERT(x)
Definition: SGIO.h:200
#define MACHINE_PROBLEM_TYPE(PT)
Definition: Machine.h:120
double float64_t
Definition: common.h:60
int32_t get_leaf_size() const
Definition: KNN.h:178
Class KNN, an implementation of the standard k-nearest neighbor classifier.
Definition: KNN.h:74
KNN_SOLVER get_knn_solver_type()
Definition: KNN.h:194
float64_t m_q
parameter q of rank weighting
Definition: KNN.h:288
SGVector< int32_t > m_train_labels
Definition: KNN.h:297
void set_k(int32_t k)
Definition: KNN.h:146
all classes and functions are contained in the shogun namespace
Definition: class_list.h:18
virtual const char * get_name() const
Definition: KNN.h:189
virtual ~CKNN()
Definition: KNN.cpp:73
CKNNSolver * solver
Solver for KNN.
Definition: KNN.h:300
void set_knn_solver_type(KNN_SOLVER knn_solver)
Definition: KNN.h:203
The class Features is the base class of all feature objects.
Definition: Features.h:68
virtual CMulticlassLabels * classify_NN()
Definition: KNN.cpp:194
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: KNN.cpp:160
virtual float64_t apply_one(int32_t vec_idx)
get output for example "vec_idx"
Definition: KNN.h:117
int32_t m_leaf_size
Definition: KNN.h:304

SHOGUN Machine Learning Toolbox - Documentation