SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
PRCEvaluation.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2011 Sergey Lisitsyn
8  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
9  */
10 
15 
16 using namespace shogun;
17 
19 {
    // Empty body: no statements. NOTE(review): the signature that belongs on listing line 18 is absent from this extract, so which member this body defines should be confirmed against the real file.
20 }
21 
23 {
24  ASSERT(predicted && ground_truth)
25  ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels())
26  ASSERT(predicted->get_label_type()==LT_BINARY)
27  ASSERT(ground_truth->get_label_type()==LT_BINARY)
28  ground_truth->ensure_valid();
29 
30  // number of true positive examples
31  float64_t tp = 0.0;
32  int32_t i;
33 
34  // total number of positive labels in predicted
35  int32_t pos_count=0;
36 
37  // initialize number of labels and labels
38  SGVector<float64_t> orig_labels = predicted->get_values();
39  int32_t length = orig_labels.vlen;
40  float64_t* labels = SGVector<float64_t>::clone_vector(orig_labels.vector, length);
41 
42  // get indexes for sort
43  int32_t* idxs = SG_MALLOC(int32_t, length);
44  for(i=0; i<length; i++)
45  idxs[i] = i;
46 
47  // sort indexes by labels ascending
48  CMath::qsort_backward_index(labels,idxs,length);
49 
50  // clean and initialize graph and auPRC
51  SG_FREE(labels);
52  m_PRC_graph = SGMatrix<float64_t>(2,length);
54  m_auPRC = 0.0;
55 
56  // get total numbers of positive and negative labels
57  for (i=0; i<length; i++)
58  {
59  if (ground_truth->get_value(i) > 0)
60  pos_count++;
61  }
62 
63  // assure number of positive examples is >0
64  ASSERT(pos_count>0)
65 
66  // create PRC curve
67  for (i=0; i<length; i++)
68  {
69  // update number of true positive examples
70  if (ground_truth->get_value(idxs[i]) > 0)
71  tp += 1.0;
72 
73  // precision (x)
74  m_PRC_graph[2*i] = tp/float64_t(i+1);
75  // recall (y)
76  m_PRC_graph[2*i+1] = tp/float64_t(pos_count);
77 
78  m_thresholds[i]= predicted->get_value(idxs[i]);
79  }
80 
81  // calc auRPC using area under curve
82  m_auPRC = CMath::area_under_curve(m_PRC_graph.matrix,length,true);
83 
84  // set computed indicator
85  m_computed = true;
86 
87  SG_FREE(idxs);
88  return m_auPRC;
89 }
90 
92 {
93  if (!m_computed)
94  SG_ERROR("Uninitialized, please call evaluate first")
95 
96  return m_PRC_graph;
97 }
98 
100 {
101  if (!m_computed)
102  SG_ERROR("Uninitialized, please call evaluate first")
103 
104  return m_thresholds;
105 }
106 
108 {
109  if (!m_computed)
110  SG_ERROR("Uninitialized, please call evaluate first")
111 
112  return m_auPRC;
113 }
SGMatrix< float64_t > get_PRC()
virtual float64_t get_value(int32_t idx)
Definition: Labels.cpp:59
virtual ELabelType get_label_type() const =0
binary labels +1/-1
Definition: LabelTypes.h:18
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
virtual int32_t get_num_labels() const =0
#define SG_ERROR(...)
Definition: SGIO.h:129
static T * clone_vector(const T *vec, int32_t len)
Definition: SGVector.cpp:213
SGVector< float64_t > get_thresholds()
index_t vlen
Definition: SGVector.h:492
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)
#define ASSERT(x)
Definition: SGIO.h:201
double float64_t
Definition: common.h:50
static float64_t area_under_curve(float64_t *xy, int32_t len, bool reversed)
Definition: Math.h:946
SGVector< float64_t > m_thresholds
Definition: PRCEvaluation.h:78
virtual SGVector< float64_t > get_values()
Definition: Labels.cpp:90
SGMatrix< float64_t > m_PRC_graph
Definition: PRCEvaluation.h:75
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
virtual void ensure_valid(const char *context=NULL)=0
static void qsort_backward_index(T1 *output, T2 *index, int32_t size)
Definition: Math.h:2246

SHOGUN Machine Learning Toolbox - Documentation