SHOGUN  5.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules
CrossValidation.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2011-2012 Heiko Strathmann
8  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
9  */
10 
12 #include <shogun/machine/Machine.h>
15 #include <shogun/base/Parameter.h>
18 #include <shogun/lib/List.h>
19 
20 using namespace shogun;
21 
23 {
24  init();
25 }
26 
28  CLabels* labels, CSplittingStrategy* splitting_strategy,
29  CEvaluation* evaluation_criterion, bool autolock) :
30  CMachineEvaluation(machine, features, labels, splitting_strategy,
31  evaluation_criterion, autolock)
32 {
33  init();
34 }
35 
37  CSplittingStrategy* splitting_strategy,
38  CEvaluation* evaluation_criterion, bool autolock) :
39  CMachineEvaluation(machine, labels, splitting_strategy, evaluation_criterion,
40  autolock)
41 {
42  init();
43 }
44 
46 {
48 }
49 
50 void CCrossValidation::init()
51 {
52  m_num_runs=1;
53 
54  /* do reference counting for output objects */
55  m_xval_outputs=new CList(true);
56 
57  SG_ADD(&m_num_runs, "num_runs", "Number of repetitions",
59  SG_ADD((CSGObject**)&m_xval_outputs, "m_xval_outputs", "List of output "
60  "classes for intermediade cross-validation results",
62 }
63 
65 {
66  SG_DEBUG("entering %s::evaluate()\n", get_name())
67 
68  REQUIRE(m_machine, "%s::evaluate() is only possible if a machine is "
69  "attached\n", get_name());
70 
71  REQUIRE(m_features, "%s::evaluate() is only possible if features are "
72  "attached\n", get_name());
73 
74  REQUIRE(m_labels, "%s::evaluate() is only possible if labels are "
75  "attached\n", get_name());
76 
77  /* if for some reason the do_unlock flag is set, unlock */
78  if (m_do_unlock)
79  {
81  m_do_unlock=false;
82  }
83 
84  /* set labels in any case (no locking needs this) */
86 
87  if (m_autolock)
88  {
89  /* if machine supports locking try to do so */
91  {
92  /* only lock if machine is not yet locked */
93  if (!m_machine->is_data_locked())
94  {
96  m_do_unlock=true;
97  }
98  }
99  else
100  {
101  SG_WARNING("%s does not support locking. Autolocking is skipped. "
102  "Set autolock flag to false to get rid of warning.\n",
103  m_machine->get_name());
104  }
105  }
106 
108 
109  /* evtl. update xvalidation output class */
112  while (current)
113  {
114  current->init_num_runs(m_num_runs);
116  current->init_expose_labels(m_labels);
117  current->post_init();
118  SG_UNREF(current);
119  current=(CCrossValidationOutput*)
121  }
122 
123  /* perform all the x-val runs */
124  SG_DEBUG("starting %d runs of cross-validation\n", m_num_runs)
125  for (index_t i=0; i <m_num_runs; ++i)
126  {
127 
128  /* evtl. update xvalidation output class */
130  while (current)
131  {
132  current->update_run_index(i);
133  SG_UNREF(current);
134  current=(CCrossValidationOutput*)
136  }
137 
138  SG_DEBUG("entering cross-validation run %d \n", i)
139  results[i]=evaluate_one_run();
140  SG_DEBUG("result of cross-validation run %d is %f\n", i, results[i])
141  }
142 
143  /* construct evaluation result */
145  result->mean=CStatistics::mean(results);
146  if (m_num_runs>1)
147  result->std_dev=CStatistics::std_deviation(results);
148  else
149  result->std_dev=0;
150 
151  /* unlock machine if it was locked in this method */
153  {
155  m_do_unlock=false;
156  }
157 
158  SG_DEBUG("leaving %s::evaluate()\n", get_name())
159 
160  SG_REF(result);
161  return result;
162 }
163 
164 void CCrossValidation::set_num_runs(int32_t num_runs)
165 {
166  if (num_runs <1)
167  SG_ERROR("%d is an illegal number of repetitions\n", num_runs)
168 
169  m_num_runs=num_runs;
170 }
171 
173 {
174  SG_DEBUG("entering %s::evaluate_one_run()\n", get_name())
176 
177  SG_DEBUG("building index sets for %d-fold cross-validation\n", num_subsets)
178 
179  /* build index sets */
181 
182  /* results array */
183  SGVector<float64_t> results(num_subsets);
184 
185  /* different behavior whether data is locked or not */
186  if (m_machine->is_data_locked())
187  {
188  SG_DEBUG("starting locked evaluation\n", get_name())
189  /* do actual cross-validation */
190  for (index_t i=0; i <num_subsets; ++i)
191  {
192  /* evtl. update xvalidation output class */
195  while (current)
196  {
197  current->update_fold_index(i);
198  SG_UNREF(current);
199  current=(CCrossValidationOutput*)
201  }
202 
203  /* index subset for training, will be freed below */
204  SGVector<index_t> inverse_subset_indices =
206 
207  /* train machine on training features */
208  m_machine->train_locked(inverse_subset_indices);
209 
210  /* feature subset for testing */
211  SGVector<index_t> subset_indices =
213 
214  /* evtl. update xvalidation output class */
216  while (current)
217  {
218  current->update_train_indices(inverse_subset_indices, "\t");
219  current->update_trained_machine(m_machine, "\t");
220  SG_UNREF(current);
221  current=(CCrossValidationOutput*)
223  }
224 
225  /* produce output for desired indices */
226  CLabels* result_labels=m_machine->apply_locked(subset_indices);
227  SG_REF(result_labels);
228 
229  /* set subset for testing labels */
230  m_labels->add_subset(subset_indices);
231 
232  /* evaluate against own labels */
233  m_evaluation_criterion->set_indices(subset_indices);
234  results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
235 
236  /* evtl. update xvalidation output class */
238  while (current)
239  {
240  current->update_test_indices(subset_indices, "\t");
241  current->update_test_result(result_labels, "\t");
242  current->update_test_true_result(m_labels, "\t");
243  current->post_update_results();
244  current->update_evaluation_result(results[i], "\t");
245  SG_UNREF(current);
246  current=(CCrossValidationOutput*)
248  }
249 
250  /* remove subset to prevent side effects */
252 
253  /* clean up */
254  SG_UNREF(result_labels);
255 
256  SG_DEBUG("done locked evaluation\n", get_name())
257  }
258  }
259  else
260  {
261  SG_DEBUG("starting unlocked evaluation\n", get_name())
262  /* tell machine to store model internally
263  * (otherwise changing subset of features will kaboom the classifier) */
265 
266  /* do actual cross-validation */
267  for (index_t i=0; i <num_subsets; ++i)
268  {
269  /* evtl. update xvalidation output class */
272  while (current)
273  {
274  current->update_fold_index(i);
275  SG_UNREF(current);
276  current=(CCrossValidationOutput*)
278  }
279 
280  /* set feature subset for training */
281  SGVector<index_t> inverse_subset_indices=
283  m_features->add_subset(inverse_subset_indices);
284  for (index_t p=0; p<m_features->get_num_preprocessors(); p++)
285  {
286  CPreprocessor* preprocessor = m_features->get_preprocessor(p);
287  preprocessor->init(m_features);
288  SG_UNREF(preprocessor);
289  }
290 
291  /* set label subset for training */
292  m_labels->add_subset(inverse_subset_indices);
293 
294  SG_DEBUG("training set %d:\n", i)
295  if (io->get_loglevel()==MSG_DEBUG)
296  {
297  SGVector<index_t>::display_vector(inverse_subset_indices.vector,
298  inverse_subset_indices.vlen, "training indices");
299  }
300 
301  /* train machine on training features and remove subset */
302  SG_DEBUG("starting training\n")
304  SG_DEBUG("finished training\n")
305 
306  /* evtl. update xvalidation output class */
308  while (current)
309  {
310  current->update_train_indices(inverse_subset_indices, "\t");
311  current->update_trained_machine(m_machine, "\t");
312  SG_UNREF(current);
313  current=(CCrossValidationOutput*)
315  }
316 
319 
320  /* set feature subset for testing (subset method that stores pointer) */
321  SGVector<index_t> subset_indices =
323  m_features->add_subset(subset_indices);
324 
325  /* set label subset for testing */
326  m_labels->add_subset(subset_indices);
327 
328  SG_DEBUG("test set %d:\n", i)
329  if (io->get_loglevel()==MSG_DEBUG)
330  {
332  subset_indices.vlen, "test indices");
333  }
334 
335  /* apply machine to test features and remove subset */
336  SG_DEBUG("starting evaluation\n")
337  SG_DEBUG("%p\n", m_features)
338  CLabels* result_labels=m_machine->apply(m_features);
339  SG_DEBUG("finished evaluation\n")
341  SG_REF(result_labels);
342 
343  /* evaluate */
344  results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
345  SG_DEBUG("result on fold %d is %f\n", i, results[i])
346 
347  /* evtl. update xvalidation output class */
349  while (current)
350  {
351  current->update_test_indices(subset_indices, "\t");
352  current->update_test_result(result_labels, "\t");
353  current->update_test_true_result(m_labels, "\t");
354  current->post_update_results();
355  current->update_evaluation_result(results[i], "\t");
356  SG_UNREF(current);
357  current=(CCrossValidationOutput*)
359  }
360 
361  /* clean up, remove subsets */
362  SG_UNREF(result_labels);
364  }
365 
366  SG_DEBUG("done unlocked evaluation\n", get_name())
367  }
368 
369  /* build arithmetic mean of results */
370  float64_t mean=CStatistics::mean(results);
371 
372  SG_DEBUG("leaving %s::evaluate_one_run()\n", get_name())
373  return mean;
374 }
375 
377  CCrossValidationOutput* cross_validation_output)
378 {
379  m_xval_outputs->append_element(cross_validation_output);
380 }
virtual void update_fold_index(index_t fold_index, const char *prefix="")
virtual void build_subsets()=0
virtual void update_train_indices(SGVector< index_t > indices, const char *prefix="")
virtual bool init(CFeatures *features)=0
CSGObject * get_next_element()
Definition: List.h:185
virtual CLabels * apply_locked(SGVector< index_t > indices)
Definition: Machine.cpp:187
int32_t index_t
Definition: common.h:62
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
static float64_t std_deviation(SGVector< float64_t > values)
Definition: Statistics.cpp:120
virtual CEvaluationResult * evaluate()
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)=0
virtual void update_test_true_result(CLabels *results, const char *prefix="")
Abstract base class for all splitting types. Takes a CLabels instance and generates a desired number ...
virtual void init_num_runs(index_t num_runs, const char *prefix="")
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
virtual void update_test_indices(SGVector< index_t > indices, const char *prefix="")
CPreprocessor * get_preprocessor(int32_t num) const
Definition: Features.cpp:93
type to encapsulate the results of an evaluation run.
virtual const char * get_name() const
Definition: Machine.h:305
virtual bool train_locked(SGVector< index_t > indices)
Definition: Machine.h:239
#define SG_REF(x)
Definition: SGObject.h:54
void set_num_runs(int32_t num_runs)
A generic learning machine interface.
Definition: Machine.h:143
virtual void set_indices(SGVector< index_t > indices)
Definition: Evaluation.h:63
void display_vector(const char *name="vector", const char *prefix="") const
Definition: SGVector.cpp:354
int32_t get_num_preprocessors() const
Definition: Features.cpp:155
virtual void update_trained_machine(CMachine *machine, const char *prefix="")
index_t vlen
Definition: SGVector.h:494
CSGObject * get_first_element()
Definition: List.h:151
virtual void set_store_model_features(bool store_model)
Definition: Machine.cpp:107
Class for managing individual folds in cross-validation.
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:115
double float64_t
Definition: common.h:50
virtual void data_unlock()
Definition: Machine.cpp:143
virtual const char * get_name() const
virtual void data_lock(CLabels *labs, CFeatures *features)
Definition: Machine.cpp:112
virtual void remove_subset()
Definition: Labels.cpp:49
Abstract class that contains the result generated by the MachineEvaluation class. ...
Machine Evaluation is an abstract class that evaluates a machine according to some criterion...
virtual void add_subset(SGVector< index_t > subset)
Definition: Labels.cpp:39
SGVector< index_t > generate_subset_inverse(index_t subset_idx)
static floatmax_t mean(SGVector< T > vec)
Definition: Statistics.h:42
EMessageType get_loglevel() const
Definition: SGIO.cpp:285
virtual void update_test_result(CLabels *results, const char *prefix="")
virtual bool supports_locking() const
Definition: Machine.h:293
virtual float64_t evaluate_one_run()
#define SG_UNREF(x)
Definition: SGObject.h:55
#define SG_DEBUG(...)
Definition: SGIO.h:107
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
SGVector< index_t > generate_subset_indices(index_t subset_idx)
virtual void remove_subset()
Definition: Features.cpp:322
virtual void update_evaluation_result(float64_t result, const char *prefix="")
The class Features is the base class of all feature objects.
Definition: Features.h:68
bool append_element(CSGObject *data)
Definition: List.h:331
virtual bool train(CFeatures *data=NULL)
Definition: Machine.cpp:39
Class Preprocessor defines a preprocessor interface.
Definition: Preprocessor.h:75
void add_cross_validation_output(CCrossValidationOutput *cross_validation_output)
#define SG_WARNING(...)
Definition: SGIO.h:128
#define SG_ADD(...)
Definition: SGObject.h:84
virtual void init_expose_labels(CLabels *labels)
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:65
bool is_data_locked() const
Definition: Machine.h:296
virtual void init_num_folds(index_t num_folds, const char *prefix="")
virtual void update_run_index(index_t run_index, const char *prefix="")
Class Evaluation, a base class for other classes used to evaluate labels, e.g. accuracy of classifica...
Definition: Evaluation.h:40
CSplittingStrategy * m_splitting_strategy
virtual void add_subset(SGVector< index_t > subset)
Definition: Features.cpp:310
Class List implements a doubly connected list for low-level-objects.
Definition: List.h:84
virtual CLabels * apply(CFeatures *data=NULL)
Definition: Machine.cpp:152

SHOGUN Machine Learning Toolbox - Documentation