SHOGUN  4.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
NearestCentroid.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Philippe Tillet
8  */
9 
14 
15 
16 
17 namespace shogun{
18 
// NOTE(review): this listing was scraped from a Doxygen page. Source lines that
// carried hyperlinks are missing: the #includes (10-13), the signature lines
// 19, 24, 33 and 49, and lines 36, 45, 51-53, 106 and 111. The fragments below
// are therefore not compilable as shown.

// Presumably the default constructor (its signature line 19 is missing here).
// All member setup is delegated to init().
20  {
21  init();
22  }
23 
// Constructor taking a distance `d` and training labels `trainlab` (signature
// line 24 is missing here). Members are first reset via init(). Both pointers
// are required to be non-NULL (ASSERT) and are then stored via set_distance()
// and set_labels().
25  {
26  init();
27  ASSERT(d)
28  ASSERT(trainlab)
29  set_distance(d);
30  set_labels(trainlab);
31  }
32 
// Presumably the destructor (signature line 33 is missing here).
// The statement run when m_is_trained is true (line 36) is not in this listing.
// The page's tooltips mention CDistance::remove_lhs(), which may be it (verify).
// When the machine was never trained, m_centroids is released with plain delete.
// NOTE(review): plain `delete` must match how init() allocates m_centroids
// (line 45, missing here). Shogun objects are usually released with SG_UNREF,
// so confirm against the full source.
34  {
35  if(m_is_trained)
37  else
38  delete m_centroids;
39  }
40 
41  void CNearestCentroid::init()
42  {
// Reset the shrinking parameter and mark the classifier as untrained.
43  m_shrinking=0;
44  m_is_trained=false;
// Line 45 is missing from this listing. train_machine() dereferences
// m_centroids, so it presumably creates the centroid feature object (verify).
46  }
47 
48 
// Body of train_machine(CFeatures* data=NULL), per the page's tooltip
// (signature line 49 and lines 51-53 are missing here).
// Computes one centroid per class from dense float64 training vectors, stores
// them in m_centroids, sets m_is_trained and returns true.
// If `data` is given, its vector count must equal the label count and the
// distance is initialised with (data, data). Otherwise the distance's lhs
// features are used as training data.
// NOTE(review): m_labels is cast to CMulticlassLabels below without a visible
// type check. The missing lines 51-53 may assert this; the tooltips mention
// get_label_type / LT_MULTICLASS (verify).
50  {
// NOTE(review): `data` is dereferenced here before the `if (data)` test below.
// Calling with data == NULL therefore crashes instead of falling back to
// distance->get_lhs(). This check belongs after `data` has been resolved.
54  ASSERT( data->get_feature_class() == C_DENSE)
55  if (data)
56  {
57  if (m_labels->get_num_labels() != data->get_num_vectors())
58  SG_ERROR("Number of training vectors does not match number of labels\n")
59  distance->init(data, data);
60  }
61  else
62  {
63  data = distance->get_lhs();
64  }
65  int32_t num_vectors = data->get_num_vectors();
66  int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
67  int32_t num_feats = ((CDenseFeatures<float64_t>*) data)->get_num_features();
// One centroid per class. Class c occupies num_feats contiguous entries
// starting at centroids.matrix + num_feats*c. The matrix is zeroed so it can
// accumulate per-class sums.
68  SGMatrix<float64_t> centroids(num_feats,num_classes);
69  centroids.zero();
70 
71  m_centroids->set_num_features(num_feats);
72  m_centroids->set_num_vectors(num_classes);
73 
// Per-class sample counters.
// NOTE(review): allocated with new[] but released with SG_FREE near the end of
// this function. Allocation and release should match (SG_MALLOC/SG_FREE or
// new[]/delete[]); confirm what SG_FREE expands to.
74  int64_t* num_per_class = new int64_t[num_classes];
75  for (int32_t i=0 ; i<num_classes ; i++)
76  {
77  num_per_class[i]=0;
78  }
79 
// Add every training vector into the slot of its class and count it.
// NOTE(review): the label is converted to int32_t with no range check visible
// here. A label outside [0, num_classes) would index past both `centroids` and
// `num_per_class`.
80  for (int32_t idx=0 ; idx<num_vectors ; idx++)
81  {
82  int32_t current_len;
83  bool current_free;
84  int32_t current_class = ((CMulticlassLabels*) m_labels)->get_label(idx);
85  float64_t* target = centroids.matrix + num_feats*current_class;
86  float64_t* current = ((CDenseFeatures<float64_t>*)data)->get_feature_vector(idx,current_len,current_free);
87  SGVector<float64_t>::add(target,1.0,target,1.0,current,current_len);
88  num_per_class[current_class]++;
89  ((CDenseFeatures<float64_t>*)data)->free_feature_vector(current, current_len, current_free);
90  }
91 
92 
// Turn each class sum into the stored centroid by scaling it in place.
// NOTE(review): for total > 1 this divides by (total-1), so the stored vector
// is sum/(n-1) rather than the arithmetic mean sum/n.
// For a class with no samples, total == 0 gives scale = 1.0/0.0 = inf, and the
// all-zero sum becomes NaN entries.
// The int64_t count is also narrowed to int32_t here.
93  for (int32_t i=0 ; i<num_classes ; i++)
94  {
95  float64_t* target = centroids.matrix + num_feats*i;
96  int32_t total = num_per_class[i];
97  float64_t scale = 0;
98  if(total>1)
99  scale = 1.0/((float64_t)(total-1));
100  else
101  scale = 1.0/(float64_t)total;
102 
103  SGVector<float64_t>::scale_vector(scale,target,num_feats);
104  }
105 
// Hand the centroid matrix to m_centroids. Lines 106 and 111 are missing from
// this listing; the page's tooltips mention CDistance::init / get_rhs /
// remove_lhs, which may be used there (verify).
107  m_centroids->set_feature_matrix(centroids);
108 
109 
110  m_is_trained=true;
112 
113  SG_FREE(num_per_class);
114 
115  return true;
116  }
117 
118 }
bool m_is_trained
Tells if the classifier has been trained or not.
virtual ELabelType get_label_type() const =0
Class Distance, a base class for all the distances used in the Shogun toolbox.
Definition: Distance.h:81
CFeatures * get_lhs()
Definition: Distance.h:195
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
virtual int32_t get_num_labels() const =0
multi-class labels 0,1,...
Definition: LabelTypes.h:20
void set_feature_matrix(SGMatrix< ST > matrix)
virtual void remove_lhs()
Takes all necessary steps if the lhs is removed from the distance matrix.
Definition: Distance.cpp:130
virtual int32_t get_num_vectors() const =0
CLabels * m_labels
Definition: Machine.h:361
#define SG_ERROR(...)
Definition: SGIO.h:129
CFeatures * get_rhs()
Definition: Distance.h:201
float64_t m_shrinking
Shrinking parameter.
A generic DistanceMachine interface.
static void scale_vector(T alpha, T *vec, int32_t len)
Scale vector inplace.
Definition: SGVector.cpp:822
Multiclass Labels for multi-class classification.
#define ASSERT(x)
Definition: SGIO.h:201
double float64_t
Definition: common.h:50
void set_num_vectors(int32_t num)
virtual EFeatureClass get_feature_class() const =0
void set_num_features(int32_t num)
CDenseFeatures< float64_t > * m_centroids
The centroids of the trained features.
virtual bool train_machine(CFeatures *data=NULL)
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
The class Features is the base class of all feature objects.
Definition: Features.h:68
void set_distance(CDistance *d)
void scale(Matrix A, Matrix B, typename Matrix::Scalar alpha)
Definition: Core.h:93
virtual bool init(CFeatures *lhs, CFeatures *rhs)
Definition: Distance.cpp:78
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:73
void add(const SGVector< T > x)
Definition: SGVector.cpp:281

SHOGUN Machine Learning Toolbox - Documentation