SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
KNN.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2006 Christian Gehl
8  * Written (W) 1999-2009 Soeren Sonnenburg
9  * Written (W) 2011 Sergey Lisitsyn
10  * Written (W) 2012 Fernando José Iglesias García, cover tree support
11  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
12  */
13 
14 #ifndef _KNN_H__
15 #define _KNN_H__
16 
17 #include <shogun/lib/config.h>
18 
19 #include <stdio.h>
20 #include <shogun/lib/common.h>
21 #include <shogun/io/SGIO.h>
25 
26 namespace shogun
27 {
28 
29 class CDistanceMachine;
30 
57 class CKNN : public CDistanceMachine
58 {
59  public:
61 
62 
63  CKNN();
64 
71  CKNN(int32_t k, CDistance* d, CLabels* trainlab);
72  virtual ~CKNN();
73 
78  virtual EMachineType get_classifier_type() { return CT_KNN; }
79 
90 
96  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
97 
99  virtual float64_t apply_one(int32_t vec_idx)
100  {
101  SG_ERROR("for performance reasons use apply() instead of apply(int32_t vec_idx)\n")
102  return 0;
103  }
104 
109 
115  virtual bool load(FILE* srcfile);
116 
122  virtual bool save(FILE* dstfile);
123 
128  inline void set_k(int32_t k)
129  {
130  ASSERT(k>0)
131  m_k=k;
132  }
133 
138  inline int32_t get_k()
139  {
140  return m_k;
141  }
142 
146  inline void set_q(float64_t q)
147  {
148  ASSERT(q<=1.0 && q>0.0)
149  m_q = q;
150  }
151 
155  inline float64_t get_q() { return m_q; }
156 
160  inline void set_use_covertree(bool use_covertree)
161  {
162  m_use_covertree = use_covertree;
163  }
164 
168  inline bool get_use_covertree() const { return m_use_covertree; }
169 
171  virtual const char* get_name() const { return "KNN"; }
172 
173  protected:
178  virtual void store_model_features();
179 
183  virtual CMulticlassLabels* classify_NN();
184 
188  void init_distance(CFeatures* data);
189 
198  virtual bool train_machine(CFeatures* data=NULL);
199 
200  private:
201  void init();
202 
215  int32_t choose_class(float64_t* classes, int32_t* train_lab);
216 
229  void choose_class_for_multiple_k(int32_t* output, int32_t* classes, int32_t* train_lab, int32_t step);
230 
231  protected:
233  int32_t m_k;
234 
237 
240 
242  int32_t m_num_classes;
243 
245  int32_t m_min_label;
246 
249 };
250 
251 }
252 #endif

SHOGUN Machine Learning Toolbox - Documentation