SHOGUN 6.0.0
KNN.cpp
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2006 Christian Gehl
 * Written (W) 2006-2009 Soeren Sonnenburg
 * Written (W) 2011 Sergey Lisitsyn
 * Written (W) 2012 Fernando José Iglesias García, cover tree support
 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
 */

#include <shogun/lib/Time.h>
#include <shogun/lib/Signal.h>
#include <shogun/multiclass/KNN.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/base/Parameter.h>

//#define DEBUG_KNN

using namespace shogun;

CKNN::CKNN()
: CDistanceMachine()
{
    init();
}

CKNN::CKNN(int32_t k, CDistance* d, CLabels* trainlab, KNN_SOLVER knn_solver)
: CDistanceMachine()
{
    init();

    m_k=k;

    ASSERT(d)
    ASSERT(trainlab)

    set_distance(d);
    set_labels(trainlab);
    m_knn_solver=knn_solver;
}
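
/* Minimal construction sketch (not part of the original implementation):
 * building a CKNN from dense features and multiclass labels, assuming the
 * corresponding headers are included. CDenseFeatures, CEuclideanDistance and
 * CMulticlassLabels are standard Shogun classes; the concrete sizes below are
 * made up for illustration.
 *
 *   SGMatrix<float64_t> feat_mat(2, 4);             // 2 dims, 4 train vectors
 *   SGVector<float64_t> lab_vec(4);                 // one label per vector
 *   // ... fill feat_mat and lab_vec ...
 *   CDenseFeatures<float64_t>* feats = new CDenseFeatures<float64_t>(feat_mat);
 *   CMulticlassLabels* labels = new CMulticlassLabels(lab_vec);
 *   CDistance* dist = new CEuclideanDistance(feats, feats);
 *   CKNN* knn = new CKNN(3, dist, labels, KNN_BRUTE); // k=3, brute-force solver
 */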

void CKNN::init()
{
    /* do not store model features by default (CDistanceMachine::apply(...) is
     * overwritten) */
    set_store_model_features(false);

    m_k=3;
    m_q=1.0;
    m_num_classes=0;
    m_leaf_size=1;
    m_knn_solver=KNN_BRUTE;
    solver=NULL;
#ifdef HAVE_CXX11
    m_lsh_l = 0;
    m_lsh_t = 0;
#endif

    /* use the method classify_for_multiple_k to experiment with different
     * values of k */
    SG_ADD(&m_k, "m_k", "Parameter k", MS_NOT_AVAILABLE);
    SG_ADD(&m_q, "m_q", "Parameter q", MS_AVAILABLE);
    SG_ADD(&m_num_classes, "m_num_classes", "Number of classes", MS_NOT_AVAILABLE);
    SG_ADD(&m_leaf_size, "m_leaf_size", "Leaf size for KDTree", MS_NOT_AVAILABLE);
    SG_ADD((machine_int_t*) &m_knn_solver, "m_knn_solver", "Algorithm to solve knn", MS_NOT_AVAILABLE);
}

CKNN::~CKNN()
{
}

bool CKNN::train_machine(CFeatures* data)
{
    ASSERT(m_labels)
    ASSERT(distance)

    if (data)
    {
        if (m_labels->get_num_labels() != data->get_num_vectors())
            SG_ERROR("Number of training vectors does not match number of labels\n")
        distance->init(data, data);
    }

    SGVector<int32_t> lab=((CMulticlassLabels*) m_labels)->get_int_labels();
    m_train_labels=lab.clone();
    ASSERT(m_train_labels.vlen>0)

    int32_t max_class=m_train_labels[0];
    int32_t min_class=m_train_labels[0];

    for (int32_t i=1; i<m_train_labels.vlen; i++)
    {
        max_class=CMath::max(max_class, m_train_labels[i]);
        min_class=CMath::min(min_class, m_train_labels[i]);
    }

    for (int32_t i=0; i<m_train_labels.vlen; i++)
        m_train_labels[i]-=min_class;

    m_min_label=min_class;
    m_num_classes=max_class-min_class+1;

    SG_INFO("m_num_classes: %d (%+d to %+d) num_train: %d\n", m_num_classes,
            min_class, max_class, m_train_labels.vlen);

    return true;
}
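
/* Worked example of the label normalization above (illustrative values only):
 * for training labels {4, 2, 3, 2}, min_class=2 and max_class=4, so
 * m_train_labels becomes {2, 0, 1, 0}, m_min_label=2 and m_num_classes=3.
 * Predictions are computed on the shifted labels and m_min_label is added back
 * whenever a label is reported (see classify_NN() below).
 */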

SGMatrix<index_t> CKNN::nearest_neighbors()
{
    //number of examples to which kNN is applied
    int32_t n=distance->get_num_vec_rhs();
    //distances to train data
    SGVector<float64_t> dists(m_train_labels.vlen);
    //indices to train data
    SGVector<index_t> train_idxs(m_train_labels.vlen);
    //pre-allocation of the nearest neighbors
    SGMatrix<index_t> NN(m_k, n);

    distance->precompute_lhs();
    distance->precompute_rhs();

    //for each test example
    for (int32_t i=0; i<n && (!CSignal::cancel_computations()); i++)
    {
        SG_PROGRESS(i, 0, n)

        //lhs idx 0..num train examples-1 (i.e., all train examples) and rhs idx i
        distances_lhs(dists,0,m_train_labels.vlen-1,i);

        //fill in an array with 0..num train examples-1
        for (int32_t j=0; j<m_train_labels.vlen; j++)
            train_idxs[j]=j;

        //sort the distance vector between test example i and all train examples
        CMath::qsort_index(dists.vector, train_idxs.vector, m_train_labels.vlen);

#ifdef DEBUG_KNN
        SG_PRINT("\nQuick sort query %d\n", i)
        for (int32_t j=0; j<m_k; j++)
            SG_PRINT("%d ", train_idxs[j])
        SG_PRINT("\n")
#endif

        //fill the output with the indices of the nearest neighbors
        for (int32_t j=0; j<m_k; j++)
            NN(j,i) = train_idxs[j];
    }

    distance->reset_precompute();

    return NN;
}
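
/* Layout note (a sketch, not from the original sources): the returned matrix
 * has m_k rows and one column per test example, with NN(j, i) holding the
 * index of the (j+1)-th closest training vector to test vector i, closest
 * first. Assuming the machine has been trained and the distance's rhs points
 * at the test features, the neighbors of test example 0 could be read as:
 *
 *   SGMatrix<index_t> nn = knn->nearest_neighbors();
 *   for (int32_t j=0; j<nn.num_rows; j++)
 *       SG_SPRINT("neighbor %d of test vector 0: train index %d\n", j, nn(j, 0));
 */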

CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data)
{
    if (data)
        init_distance(data);

    //redirecting to fast (without sorting) classify if k==1
    if (m_k == 1)
        return classify_NN();

    ASSERT(m_num_classes>0)
    ASSERT(distance)
    ASSERT(distance->get_num_vec_rhs())

    int32_t num_lab=distance->get_num_vec_rhs();
    ASSERT(m_k<=distance->get_num_vec_lhs())

    //labels of the k nearest neighbors
    SGVector<int32_t> train_lab(m_k);

    SG_INFO("%d test examples\n", num_lab)
    CSignal::clear_cancel();

    //histogram of classes and returned output
    SGVector<float64_t> classes(m_num_classes);

    init_solver(m_knn_solver);

    CMulticlassLabels* output = solver->classify_objects(distance, num_lab, train_lab, classes);

    SG_UNREF(solver);

    return output;
}
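
/* Minimal end-to-end sketch (assumptions: knn, feats and labels as in the
 * construction sketch after the constructor above; test_feats is a
 * CDenseFeatures object holding the query vectors):
 *
 *   knn->train(feats);
 *   CMulticlassLabels* pred = knn->apply_multiclass(test_feats);
 *   for (int32_t i=0; i<pred->get_num_labels(); i++)
 *       SG_SPRINT("prediction for test vector %d: %f\n", i, pred->get_label(i));
 *   SG_UNREF(pred);
 */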

CMulticlassLabels* CKNN::classify_NN()
{
    ASSERT(distance)
    ASSERT(m_num_classes>0)

    int32_t num_lab = distance->get_num_vec_rhs();
    ASSERT(num_lab)

    CMulticlassLabels* output = new CMulticlassLabels(num_lab);
    SGVector<float64_t> distances(m_train_labels.vlen);

    SG_INFO("%d test examples\n", num_lab)
    CSignal::clear_cancel();

    distance->precompute_lhs();

    // for each test example
    for (int32_t i=0; i<num_lab && (!CSignal::cancel_computations()); i++)
    {
        SG_PROGRESS(i,0,num_lab)

        // get distances from the i-th test example to train examples 0..m_train_labels.vlen-1
        distances_lhs(distances,0,m_train_labels.vlen-1,i);
        int32_t j;

        // assume the 0th train example is nearest to the i-th test example
        int32_t out_idx = 0;
        float64_t min_dist = distances.vector[0];

        // search for the nearest neighbor by comparing distances
        for (j=0; j<m_train_labels.vlen; j++)
        {
            if (distances.vector[j]<min_dist)
            {
                min_dist = distances.vector[j];
                out_idx = j;
            }
        }

        // label the i-th test example with the label of the nearest neighbor at index out_idx
        output->set_label(i,m_train_labels.vector[out_idx]+m_min_label);
    }

    distance->reset_precompute();

    return output;
}

SGMatrix<int32_t> CKNN::classify_for_multiple_k()
{
    ASSERT(distance)
    ASSERT(m_num_classes>0)

    int32_t num_lab=distance->get_num_vec_rhs();
    ASSERT(m_k<=num_lab)

    //working buffer of m_train_labels
    SGVector<int32_t> train_lab(m_k);

    //histogram of classes and returned output
    SGVector<int32_t> classes(m_num_classes);

    SG_INFO("%d test examples\n", num_lab)
    CSignal::clear_cancel();

    init_solver(m_knn_solver);

    SGVector<int32_t> output = solver->classify_objects_k(distance, num_lab, train_lab, classes);

    SG_UNREF(solver);

    return SGMatrix<int32_t>(output,num_lab,m_k);
}
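
/* Output layout (a sketch based on the return statement above, illustrative
 * only): the result has num_lab rows and m_k columns, entry (i, j) being the
 * label predicted for test example i when only the j+1 nearest neighbors are
 * used. Assuming the machine is trained and the test features are attached:
 *
 *   SGMatrix<int32_t> multi_k = knn->classify_for_multiple_k();
 *   for (int32_t j=0; j<multi_k.num_cols; j++)
 *       SG_SPRINT("k=%d predicts %d for test vector 0\n", j+1, multi_k(0, j));
 */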

void CKNN::init_distance(CFeatures* data)
{
    if (!distance)
        SG_ERROR("No distance assigned!\n")
    CFeatures* lhs=distance->get_lhs();
    if (!lhs || !lhs->get_num_vectors())
    {
        SG_UNREF(lhs);
        SG_ERROR("No vectors on left hand side\n")
    }
    distance->init(lhs, data);
    SG_UNREF(lhs);
}

bool CKNN::load(FILE* srcfile)
{
    SG_SET_LOCALE_C;
    SG_RESET_LOCALE;
    return false;
}

bool CKNN::save(FILE* dstfile)
{
    SG_SET_LOCALE_C;
    SG_RESET_LOCALE;
    return false;
}

void CKNN::store_model_features()
{
    CFeatures* d_lhs=distance->get_lhs();
    CFeatures* d_rhs=distance->get_rhs();

    /* copy lhs of underlying distance */
    distance->init(d_lhs->duplicate(), d_rhs);

    SG_UNREF(d_lhs);
    SG_UNREF(d_rhs);
}

void CKNN::init_solver(KNN_SOLVER knn_solver)
{
    switch (knn_solver)
    {
    case KNN_BRUTE:
    {
        SGMatrix<index_t> NN = nearest_neighbors();
        solver = new CBruteKNNSolver(m_k, m_q, m_num_classes, m_min_label, m_train_labels, NN);
        SG_REF(solver);
        break;
    }
    case KNN_KDTREE:
    {
        solver = new CKDTREEKNNSolver(m_k, m_q, m_num_classes, m_min_label, m_train_labels, m_leaf_size);
        SG_REF(solver);
        break;
    }
    case KNN_COVER_TREE:
    {
        solver = new CCoverTreeKNNSolver(m_k, m_q, m_num_classes, m_min_label, m_train_labels);
        SG_REF(solver);
        break;
    }
#ifdef HAVE_CXX11
    case KNN_LSH:
    {
        solver = new CLSHKNNSolver(m_k, m_q, m_num_classes, m_min_label, m_train_labels, m_lsh_l, m_lsh_t);
        SG_REF(solver);
        break;
    }
#endif
    }
}