SHOGUN  6.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules
CrossValidation.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2011-2012 Heiko Strathmann
8  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
9  */
10 
12 #include <shogun/machine/Machine.h>
15 #include <shogun/base/Parameter.h>
18 #include <shogun/lib/List.h>
19 
20 using namespace shogun;
21 
23 {
 /* NOTE(review): listing line 22 (the constructor signature) is missing from
  * this page; presumably this is the default constructor. Its only visible
  * work is calling init() to set member defaults and register parameters. */
24  init();
25 }
26 
28  CLabels* labels, CSplittingStrategy* splitting_strategy,
29  CEvaluation* evaluation_criterion, bool autolock) :
30  CMachineEvaluation(machine, features, labels, splitting_strategy,
31  evaluation_criterion, autolock)
32 {
 /* Constructor taking features: forwards machine, features, labels,
  * splitting strategy, evaluation criterion and autolock flag to
  * CMachineEvaluation, then calls init().
  * NOTE(review): listing line 27 (the start of the signature) is missing
  * from this page. */
33  init();
34 }
35 
37  CSplittingStrategy* splitting_strategy,
38  CEvaluation* evaluation_criterion, bool autolock) :
39  CMachineEvaluation(machine, labels, splitting_strategy, evaluation_criterion,
40  autolock)
41 {
 /* Constructor without features: forwards machine, labels, splitting
  * strategy, evaluation criterion and autolock flag to CMachineEvaluation,
  * then calls init().
  * NOTE(review): listing line 36 (the start of the signature) is missing
  * from this page. */
42  init();
43 }
44 
46 {
 /* NOTE(review): listing lines 45 (signature) and 47 (body) are missing from
  * this page; presumably this is the destructor. init() allocates
  * m_xval_outputs with new CList(true), so the missing body line likely
  * releases it (SG_UNREF) -- confirm against the real source. */
48 }
49 
/* Sets member defaults and registers parameters: a single run by default, and
 * an output list created with reference counting enabled (new CList(true)).
 * Both m_num_runs and m_xval_outputs are registered via SG_ADD.
 * NOTE(review): listing lines 58 and 61 (the trailing SG_ADD arguments) are
 * missing from this page. The parameter description "intermediade" is a typo
 * for "intermediate"; it is a runtime string, so it is left unchanged here. */
50 void CCrossValidation::init()
51 {
52  m_num_runs=1;
53 
54  /* do reference counting for output objects */
55  m_xval_outputs=new CList(true);
56 
57  SG_ADD(&m_num_runs, "num_runs", "Number of repetitions",
59  SG_ADD((CSGObject**)&m_xval_outputs, "m_xval_outputs", "List of output "
60  "classes for intermediade cross-validation results",
62 }
63 
65 {
 /* Runs m_num_runs repetitions of evaluate_one_run(), stores each run's
  * score in results, and returns a result object (reference added with
  * SG_REF) holding their mean and standard deviation; std_dev is 0 when
  * there is only one run. Requires a machine, features and labels.
  * If m_autolock is set and the machine supports locking and is not yet
  * locked, m_do_unlock is set so the machine is unlocked again at the end.
  * NOTE(review): several source lines are missing from this listing (gaps
  * in the line numbers), including the signature, the lock/unlock calls,
  * the results vector and the result object construction. */
66  SG_DEBUG("entering %s::evaluate()\n", get_name())
67 
68  REQUIRE(m_machine, "%s::evaluate() is only possible if a machine is "
69  "attached\n", get_name());
70 
71  REQUIRE(m_features, "%s::evaluate() is only possible if features are "
72  "attached\n", get_name());
73 
74  REQUIRE(m_labels, "%s::evaluate() is only possible if labels are "
75  "attached\n", get_name());
76 
77  /* if for some reason the m_do_unlock flag is set, unlock */
78  if (m_do_unlock)
79  {
81  m_do_unlock=false;
82  }
83 
84  /* set labels in any case (no locking needs this) */
86 
87  if (m_autolock)
88  {
89  /* if machine supports locking try to do so */
91  {
92  /* only lock if machine is not yet locked */
93  if (!m_machine->is_data_locked())
94  {
96  m_do_unlock=true;
97  }
98  }
99  else
100  {
101  SG_WARNING("%s does not support locking. Autolocking is skipped. "
102  "Set autolock flag to false to get rid of warning.\n",
103  m_machine->get_name());
104  }
105  }
106 
108 
109  /* initialize the registered cross-validation output objects, if any */
112  while (current)
113  {
114  current->init_num_runs(m_num_runs);
116  current->init_expose_labels(m_labels);
117  current->post_init();
118  SG_UNREF(current);
119  current=(CCrossValidationOutput*)
121  }
122 
123  /* perform all the x-val runs */
124  SG_DEBUG("starting %d runs of cross-validation\n", m_num_runs)
125  for (index_t i=0; i <m_num_runs; ++i)
126  {
127 
128  /* tell the registered cross-validation output objects, if any, the current run index */
130  while (current)
131  {
132  current->update_run_index(i);
133  SG_UNREF(current);
134  current=(CCrossValidationOutput*)
136  }
137 
138  SG_DEBUG("entering cross-validation run %d \n", i)
139  results[i]=evaluate_one_run();
140  SG_DEBUG("result of cross-validation run %d is %f\n", i, results[i])
141  }
142 
143  /* construct evaluation result */
145  result->mean=CStatistics::mean(results);
146  if (m_num_runs>1)
147  result->std_dev=CStatistics::std_deviation(results);
148  else
149  result->std_dev=0;
150 
151  /* unlock machine if it was locked in this method */
153  {
155  m_do_unlock=false;
156  }
157 
158  SG_DEBUG("leaving %s::evaluate()\n", get_name())
159 
160  SG_REF(result);
161  return result;
162 }
163 
164 void CCrossValidation::set_num_runs(int32_t num_runs)
165 {
166  if (num_runs <1)
167  SG_ERROR("%d is an illegal number of repetitions\n", num_runs)
168 
169  m_num_runs=num_runs;
170 }
171 
173 {
 /* Runs one complete cross-validation pass over all folds and returns the
  * mean of the per-fold scores. If the machine's data is locked, each fold
  * is trained/applied through train_locked()/apply_locked() on index
  * subsets; otherwise subsets are added to the features and labels and the
  * machine is trained/applied with train()/apply().
  * NOTE(review): several source lines are missing from this listing (gaps
  * in the line numbers), including where num_subsets (175), the index sets
  * (180) and the returned mean (407) are computed. */
174  SG_DEBUG("entering %s::evaluate_one_run()\n", get_name())
176 
177  SG_DEBUG("building index sets for %d-fold cross-validation\n", num_subsets)
178 
179  /* build index sets */
181 
182  /* results array */
183  SGVector<float64_t> results(num_subsets);
184 
185  /* different behavior whether data is locked or not */
186  if (m_machine->is_data_locked())
187  {
 /* NOTE(review): the format string below has no %s, so get_name() is unused. */
188  SG_DEBUG("starting locked evaluation\n", get_name())
189  /* do actual cross-validation */
190  for (index_t i=0; i <num_subsets; ++i)
191  {
192  /* tell the registered cross-validation output objects, if any, the current fold index */
195  while (current)
196  {
197  current->update_fold_index(i);
198  SG_UNREF(current);
199  current=(CCrossValidationOutput*)
201  }
202 
203  /* index subset for training, will be freed below */
204  SGVector<index_t> inverse_subset_indices =
206 
207  /* train machine on training features */
208  m_machine->train_locked(inverse_subset_indices);
209 
210  /* feature subset for testing */
211  SGVector<index_t> subset_indices =
213 
214  /* pass training indices and trained machine to the registered output objects, if any */
216  while (current)
217  {
218  current->update_train_indices(inverse_subset_indices, "\t");
219  current->update_trained_machine(m_machine, "\t");
220  SG_UNREF(current);
221  current=(CCrossValidationOutput*)
223  }
224 
225  /* produce output for desired indices */
226  CLabels* result_labels=m_machine->apply_locked(subset_indices);
227  SG_REF(result_labels);
228 
229  /* set subset for testing labels */
230  m_labels->add_subset(subset_indices);
231 
232  /* evaluate against own labels */
233  m_evaluation_criterion->set_indices(subset_indices);
234  results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
235 
236  /* pass test indices, predictions, true labels and score to the registered output objects, if any */
238  while (current)
239  {
240  current->update_test_indices(subset_indices, "\t");
241  current->update_test_result(result_labels, "\t");
242  current->update_test_true_result(m_labels, "\t");
243  current->post_update_results();
244  current->update_evaluation_result(results[i], "\t");
245  SG_UNREF(current);
246  current=(CCrossValidationOutput*)
248  }
249 
250  /* remove subset to prevent side effects */
252 
253  /* clean up */
254  SG_UNREF(result_labels);
255 
 /* NOTE(review): this message sits inside the fold loop, so it is logged
  * once per fold rather than once at the end; get_name() is also unused. */
256  SG_DEBUG("done locked evaluation\n", get_name())
257  }
258  }
259  else
260  {
 /* NOTE(review): the format string below has no %s, so get_name() is unused. */
261  SG_DEBUG("starting unlocked evaluation\n", get_name())
262  /* tell machine to store model internally
263  * (otherwise changing the feature subset would break the trained classifier) */
265 
266  /* do actual cross-validation */
267 
268  //TODO parallel xvalidation needs some serious fixing, see #3743
269  //#pragma omp parallel for
270  for (index_t i=0; i <num_subsets; ++i)
271  {
272  CMachine* machine;
273  CFeatures* features;
274  CLabels* labels;
275  CEvaluation* evaluation_criterion;
276 
 /* With one thread the members are used directly; otherwise machine,
  * features and criterion are clones (and labels come from the cloned
  * machine), all of which are released at the end of this fold. */
277  if (get_global_parallel()->get_num_threads()==1)
278  {
279  machine=m_machine;
280  features=m_features;
281  evaluation_criterion=m_evaluation_criterion;
282  }
283  else
284  {
285  machine=(CMachine*)m_machine->clone();
286  features=(CFeatures*)m_features->clone();
287  evaluation_criterion=(CEvaluation*)m_evaluation_criterion->clone();
288  }
289 
290  /* tell the registered cross-validation output objects, if any, the current fold index */
291  CCrossValidationOutput* current;
292  #pragma omp critical
293  {
294  current=(CCrossValidationOutput*)
296  while (current)
297  {
298  current->update_fold_index(i);
299  SG_UNREF(current);
300  current=(CCrossValidationOutput*)
302  }
303  }
304 
305  /* set feature subset for training */
306  SGVector<index_t> inverse_subset_indices=
308 
309  features->add_subset(inverse_subset_indices);
310 
311  /* set label subset for training */
312  if (get_global_parallel()->get_num_threads()==1)
313  labels=m_labels;
314  else
315  labels=machine->get_labels();
316  labels->add_subset(inverse_subset_indices);
317 
318  SG_DEBUG("training set %d:\n", i)
319  if (io->get_loglevel()==MSG_DEBUG)
320  {
321  SGVector<index_t>::display_vector(inverse_subset_indices.vector,
322  inverse_subset_indices.vlen, "training indices");
323  }
324 
325  /* train machine on training features and remove subset */
326  SG_DEBUG("starting training\n")
327  machine->train(features);
328  SG_DEBUG("finished training\n")
329 
330  /* pass training indices and trained machine to the registered output objects, if any */
331  #pragma omp critical
332  {
334  while (current)
335  {
336  current->update_train_indices(inverse_subset_indices, "\t");
337  current->update_trained_machine(machine, "\t");
338  SG_UNREF(current);
339  current=(CCrossValidationOutput*)
341  }
342  }
343 
344  features->remove_subset();
345  labels->remove_subset();
346 
347  /* set feature subset for testing (subset method that stores pointer) */
348  SGVector<index_t> subset_indices =
350  features->add_subset(subset_indices);
351 
352  /* set label subset for testing */
353  labels->add_subset(subset_indices);
354 
355  SG_DEBUG("test set %d:\n", i)
356  if (io->get_loglevel()==MSG_DEBUG)
357  {
359  subset_indices.vlen, "test indices");
360  }
361 
362  /* apply machine to test features and remove subset */
363  SG_DEBUG("starting evaluation\n")
364  SG_DEBUG("%p\n", features)
365  CLabels* result_labels=machine->apply(features);
366  SG_DEBUG("finished evaluation\n")
367  features->remove_subset();
368  SG_REF(result_labels);
369 
370  /* evaluate */
371  results[i]=evaluation_criterion->evaluate(result_labels, labels);
372  SG_DEBUG("result on fold %d is %f\n", i, results[i])
373 
374  /* pass test indices, predictions, true labels and score to the registered output objects, if any */
375  #pragma omp critical
376  {
378  while (current)
379  {
380  current->update_test_indices(subset_indices, "\t");
381  current->update_test_result(result_labels, "\t");
382  current->update_test_true_result(labels, "\t");
383  current->post_update_results();
384  current->update_evaluation_result(results[i], "\t");
385  SG_UNREF(current);
386  current=(CCrossValidationOutput*)
388  }
389  }
390 
391  /* clean up, remove subsets */
392  labels->remove_subset();
393  if (get_global_parallel()->get_num_threads()!=1)
394  {
395  SG_UNREF(machine);
396  SG_UNREF(features);
397  SG_UNREF(labels);
398  SG_UNREF(evaluation_criterion);
399  }
400  SG_UNREF(result_labels);
401  }
402 
 /* NOTE(review): the format string below has no %s, so get_name() is unused. */
403  SG_DEBUG("done unlocked evaluation\n", get_name())
404  }
405 
406  /* build arithmetic mean of results */
408 
409  SG_DEBUG("leaving %s::evaluate_one_run()\n", get_name())
410  return mean;
411 }
412 
414  CCrossValidationOutput* cross_validation_output)
415 {
 /* Registers an output object by appending it to m_xval_outputs; the
  * evaluation code iterates over this list to report runs, folds, indices
  * and results. NOTE(review): listing line 413 (the start of the signature)
  * is missing from this page. */
416  m_xval_outputs->append_element(cross_validation_output);
417 }
virtual void update_fold_index(index_t fold_index, const char *prefix="")
virtual void build_subsets()=0
virtual void update_train_indices(SGVector< index_t > indices, const char *prefix="")
Parallel * get_global_parallel()
Definition: SGObject.cpp:310
CSGObject * get_next_element()
Definition: List.h:185
virtual CLabels * apply_locked(SGVector< index_t > indices)
Definition: Machine.cpp:187
int32_t index_t
Definition: common.h:72
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
virtual CSGObject * clone()
Definition: SGObject.cpp:729
static float64_t std_deviation(SGVector< float64_t > values)
Definition: Statistics.cpp:119
virtual CEvaluationResult * evaluate()
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)=0
virtual void update_test_true_result(CLabels *results, const char *prefix="")
Abstract base class for all splitting types. Takes a CLabels instance and generates a desired number ...
virtual void init_num_runs(index_t num_runs, const char *prefix="")
#define SG_ERROR(...)
Definition: SGIO.h:128
#define REQUIRE(x,...)
Definition: SGIO.h:205
virtual void update_test_indices(SGVector< index_t > indices, const char *prefix="")
std::enable_if<!std::is_same< T, complex128_t >::value, float64_t >::type mean(const Container< T > &a)
type to encapsulate the results of an evaluation run.
virtual const char * get_name() const
Definition: Machine.h:309
virtual bool train_locked(SGVector< index_t > indices)
Definition: Machine.h:240
#define SG_REF(x)
Definition: SGObject.h:52
void set_num_runs(int32_t num_runs)
A generic learning machine interface.
Definition: Machine.h:143
virtual void set_indices(SGVector< index_t > indices)
Definition: Evaluation.h:63
void display_vector(const char *name="vector", const char *prefix="") const
Definition: SGVector.cpp:396
virtual void update_trained_machine(CMachine *machine, const char *prefix="")
index_t vlen
Definition: SGVector.h:545
CSGObject * get_first_element()
Definition: List.h:151
virtual void set_store_model_features(bool store_model)
Definition: Machine.cpp:107
Class for managing individual folds in cross-validation.
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:125
double float64_t
Definition: common.h:60
virtual void data_unlock()
Definition: Machine.cpp:143
virtual const char * get_name() const
virtual void data_lock(CLabels *labs, CFeatures *features)
Definition: Machine.cpp:112
virtual void remove_subset()
Definition: Labels.cpp:49
Abstract class that contains the result generated by the MachineEvaluation class. ...
Machine Evaluation is an abstract class that evaluates a machine according to some criterion...
virtual CLabels * get_labels()
Definition: Machine.cpp:76
virtual void add_subset(SGVector< index_t > subset)
Definition: Labels.cpp:39
SGVector< index_t > generate_subset_inverse(index_t subset_idx)
static floatmax_t mean(SGVector< T > vec)
Definition: Statistics.h:42
EMessageType get_loglevel() const
Definition: SGIO.cpp:315
virtual void update_test_result(CLabels *results, const char *prefix="")
virtual bool supports_locking() const
Definition: Machine.h:297
virtual float64_t evaluate_one_run()
#define SG_UNREF(x)
Definition: SGObject.h:53
#define SG_DEBUG(...)
Definition: SGIO.h:106
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
SGVector< index_t > generate_subset_indices(index_t subset_idx)
virtual void remove_subset()
Definition: Features.cpp:322
virtual void update_evaluation_result(float64_t result, const char *prefix="")
The class Features is the base class of all feature objects.
Definition: Features.h:68
bool append_element(CSGObject *data)
Definition: List.h:331
virtual bool train(CFeatures *data=NULL)
Definition: Machine.cpp:39
void add_cross_validation_output(CCrossValidationOutput *cross_validation_output)
#define SG_WARNING(...)
Definition: SGIO.h:127
#define SG_ADD(...)
Definition: SGObject.h:94
virtual void init_expose_labels(CLabels *labels)
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:65
bool is_data_locked() const
Definition: Machine.h:300
virtual void init_num_folds(index_t num_folds, const char *prefix="")
virtual void update_run_index(index_t run_index, const char *prefix="")
Class Evaluation, a base class for other classes used to evaluate labels, e.g. accuracy of classifica...
Definition: Evaluation.h:40
CSplittingStrategy * m_splitting_strategy
virtual void add_subset(SGVector< index_t > subset)
Definition: Features.cpp:310
Class List implements a doubly connected list for low-level-objects.
Definition: List.h:84
virtual CLabels * apply(CFeatures *data=NULL)
Definition: Machine.cpp:152

SHOGUN Machine Learning Toolbox - Documentation