SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
CrossValidationMulticlassStorage.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann
8  */
9 
14 
15 using namespace shogun;
16 
// Constructs a storage object that collects per-fold cross-validation outputs
// for a multiclass problem (the declaration shown on this page gives the
// defaults compute_ROC=true, compute_PRC=false, compute_conf_matrices=false).
//   compute_ROC           - copied to m_compute_ROC; gates ROC-graph allocation
//                           and ROC evaluation further down this file.
//   compute_PRC           - copied to m_compute_PRC; same role for PRC graphs.
//   compute_conf_matrices - copied to m_compute_conf_matrices.
// All other visible state starts empty: not initialized, no label pointers,
// zero classes, no ROC-graph or confusion-matrix arrays.
17 CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(bool compute_ROC, bool compute_PRC, bool compute_conf_matrices) :
// NOTE(review): the initializer that follows ':' is missing from this rendered
// listing; confirm it against the full source.
19 {
20  m_initialized = false;
21  m_compute_ROC = compute_ROC;
22  m_compute_PRC = compute_PRC;
23  m_compute_conf_matrices = compute_conf_matrices;
24  m_pred_labels = NULL;
25  m_true_labels = NULL;
26  m_num_classes = 0;
28 
29  m_fold_ROC_graphs=NULL;
30  m_conf_matrices=NULL;
// NOTE(review): in the visible lines m_fold_PRC_graphs is not reset to NULL,
// unlike m_fold_ROC_graphs and m_conf_matrices. If the destructor can reach
// SG_FREE(m_fold_PRC_graphs) before the PRC array is allocated, that frees an
// indeterminate pointer -- confirm against the full source.
31 }
32 
33 
// Destructor. Its signature line is missing from this rendered listing
// (presumably ~CCrossValidationMulticlassStorage -- confirm).
// Releases, via SG_FREE, the arrays holding ROC graphs, PRC graphs and
// confusion matrices. The condition guarding each block, and the body of the
// last block, are also missing from this listing.
// NOTE(review): the graph arrays are allocated with SG_MALLOC as arrays of
// SGMatrix<float64_t>. SG_FREE is not shown to run the SGMatrix element
// destructors first, so verify that the matrices' own storage is not leaked.
35 {
37  {
38  SG_FREE(m_fold_ROC_graphs);
39  }
40 
42  {
43  SG_FREE(m_fold_PRC_graphs);
44  }
45 
47  {
48  SG_FREE(m_conf_matrices);
49  }
50 
52  {
54  }
// NOTE(review): the ';' after the closing brace of a function body is redundant.
55 };
56 
57 
// One-time allocation of per-fold result storage. The signature line is
// missing from this listing (presumably post_init -- confirm).
// Calling it again after it has completed raises SG_ERROR, because
// m_initialized is set to true at the end.
// Every array size depends on m_num_folds, m_num_runs and m_num_classes, so
// these must be set beforehand. m_num_classes starts at 0 in the constructor
// and is filled from the labels by the function that calls get_num_classes()
// further down this file. If that has not run yet, the sizes computed here are 0.
59 {
60  if (m_initialized)
61  SG_ERROR("CrossValidationMulticlassStorage was already initialized once\n")
62 
// ROC graphs: m_num_folds*m_num_runs*m_num_classes entries, i.e. one per
// (run, fold, class).
63  if (m_compute_ROC)
64  {
65  SG_DEBUG("Allocating %d ROC graphs\n", m_num_folds*m_num_runs*m_num_classes)
66  m_fold_ROC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
67  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
// (the loop body -- presumably constructing each SGMatrix in the raw
// SG_MALLOC storage -- is missing from this listing)
69  }
70 
// PRC graphs: same count and layout as the ROC graphs.
71  if (m_compute_PRC)
72  {
73  SG_DEBUG("Allocating %d PRC graphs\n", m_num_folds*m_num_runs*m_num_classes)
74  m_fold_PRC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
75  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
// (loop body missing from this listing, as above)
77  }
78 
81 
83 
// The guard and loop body of this block are missing from this listing. It
// iterates m_num_folds*m_num_runs times, i.e. one entry per (run, fold);
// presumably these are the confusion matrices -- confirm.
85  {
87  for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
89  }
90 
91  m_initialized = true;
92 }
93 
// Reads the number of classes from the given labels into m_num_classes. The
// signature line is missing from this listing; the parameter is named labels.
// m_num_classes drives the per-class storage sizes and the per-class
// evaluation loop elsewhere in this file.
// NOTE(review): ASSERT on the C-style-cast pointer can only catch a NULL
// pointer. It does not verify that labels really is a CMulticlassLabels.
// Checking the label type (or using dynamic_cast) would catch misuse here.
95 {
96  ASSERT((CMulticlassLabels*)labels)
97  m_num_classes = ((CMulticlassLabels*)labels)->get_num_classes();
98 }
99 
// Computes per-class binary metrics for the current run and fold
// (m_current_run_index, m_current_fold_index). The signature line is missing
// from this listing.
// For every class c, binary labels are derived from both the predicted and the
// true multiclass labels with get_binary_for_class(c) (presumably a one-vs-rest
// view -- confirm). Then:
//  * if m_compute_ROC / m_compute_PRC is set, the ROC / PRC evaluator is run and
//    get_ROC() / get_PRC() is called. The target receiving each result sits on
//    a line missing from this listing; as rendered, the results look discarded.
//  * each of the n_evals binary evaluators scores the pair. The line that
//    fetches `evaluator` is missing. The score is written into
//    m_evaluations_results, laid out row-major as [run][fold][class][evaluator].
// The temporary binary labels and evaluators are released with SG_UNREF.
// NOTE(review): m_pred_labels / m_true_labels are dereferenced without a NULL
// check. They are NULL from the constructor until set by the two
// label-storing functions at the end of this file.
101 {
102  CROCEvaluation eval_ROC;
103  CPRCEvaluation eval_PRC;
104  int32_t n_evals = m_binary_evaluations->get_num_elements();
105  for (int32_t c=0; c<m_num_classes; c++)
106  {
// NOTE(review): this message says "ROC" even when only PRC or other evaluators
// run, and it lacks the trailing "\n" used by the other messages in this file.
107  SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c)
108  CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c);
109  CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c);
110  if (m_compute_ROC)
111  {
112  eval_ROC.evaluate(pred_labels_binary, true_labels_binary);
114  eval_ROC.get_ROC();
115  }
116  if (m_compute_PRC)
117  {
118  eval_PRC.evaluate(pred_labels_binary, true_labels_binary);
120  eval_PRC.get_PRC();
121  }
122 
123  for (int32_t i=0; i<n_evals; i++)
124  {
// Flat index = ((run*F + fold)*C + c)*E + i, with F=m_num_folds,
// C=m_num_classes and E=n_evals.
126  m_evaluations_results[m_current_run_index*m_num_folds*m_num_classes*n_evals+m_current_fold_index*m_num_classes*n_evals+c*n_evals+i] =
127  evaluator->evaluate(pred_labels_binary, true_labels_binary);
128  SG_UNREF(evaluator);
129  }
130 
131  SG_UNREF(pred_labels_binary);
132  SG_UNREF(true_labels_binary);
133  }
// Used by lines missing from this listing (presumably the per-fold multiclass
// accuracy and/or confusion matrix -- confirm).
134  CMulticlassAccuracy accuracy;
135 
137 
139  {
141  }
142 }
143 
// Stores the predicted multiclass labels for the current fold. The signature
// line is missing from this listing; presumably
// update_test_result(CLabels* results, const char* prefix) -- confirm.
// NOTE(review): the pointer is C-style cast without a type check and stored
// without SG_REF. Ownership presumably stays with the caller -- confirm that
// it outlives the per-fold evaluation that dereferences m_pred_labels.
145 {
146  m_pred_labels = (CMulticlassLabels*)results;
147 }
148 
// Stores the ground-truth multiclass labels for the current fold. The
// signature line is missing from this listing; presumably
// update_test_true_result(CLabels* results, const char* prefix) -- confirm.
// NOTE(review): the pointer is C-style cast without a type check and stored
// without SG_REF. Ownership presumably stays with the caller -- confirm that
// it outlives the per-fold evaluation that dereferences m_true_labels.
150 {
151  m_true_labels = (CMulticlassLabels*)results;
152 }
153 
SGMatrix< float64_t > get_PRC()
CBinaryLabels * get_binary_for_class(int32_t i)
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
#define SG_ERROR(...)
Definition: SGIO.h:129
SGMatrix< float64_t > get_ROC()
The class MulticlassAccuracy is used to compute the accuracy of multiclass classification.
Multiclass labels for multiclass classification.
virtual void update_test_true_result(CLabels *results, const char *prefix="")
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)
Class ROCEvaluation, used to evaluate the ROC (Receiver Operating Characteristic) curve and the area under the ROC curve (auROC).
Definition: ROCEvaluation.h:32
Class for managing individual folds in cross-validation.
#define ASSERT(x)
Definition: SGIO.h:201
static SGMatrix< int32_t > get_confusion_matrix(CLabels *predicted, CLabels *ground_truth)
CCrossValidationMulticlassStorage(bool compute_ROC=true, bool compute_PRC=false, bool compute_conf_matrices=false)
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an array.
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)=0
#define SG_UNREF(x)
Definition: SGObject.h:52
#define SG_DEBUG(...)
Definition: SGIO.h:107
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
CSGObject * get_element_safe(int32_t index) const
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)
Class PRCEvaluation, used to evaluate the PRC (Precision-Recall Curve) and the area under the PRC curve (auPRC).
Definition: PRCEvaluation.h:27
Binary Labels for binary classification.
Definition: BinaryLabels.h:37
The class TwoClassEvaluation, a base class used to evaluate binary classification labels.
virtual void update_test_result(CLabels *results, const char *prefix="")

SHOGUN Machine Learning Toolbox - Documentation