SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
StructuredOutputMachine.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Written (W) 2013 Thoralf Klein
9  * Written (W) 2012 Fernando José Iglesias García
10  * Copyright (C) 2012 Fernando José Iglesias García
11  */
12 
17 
18 using namespace shogun;
19 
21 : CMachine(), m_model(NULL), m_surrogate_loss(NULL)
22 {
23  register_parameters();
24 }
25 
27  CStructuredModel* model,
28  CStructuredLabels* labs)
29 : CMachine(), m_model(model), m_surrogate_loss(NULL)
30 {
31  SG_REF(m_model);
32  set_labels(labs);
33  register_parameters();
34 }
35 
37 {
41 }
42 
44 {
45  SG_REF(model);
47  m_model = model;
48 }
49 
51 {
52  SG_REF(m_model);
53  return m_model;
54 }
55 
56 void CStructuredOutputMachine::register_parameters() // Registers this machine's members via SG_ADD and sets defaults for m_verbose / m_helper; called from both constructors.
57 {
58  SG_ADD((CSGObject**)&m_model, "m_model", "Structured model", MS_NOT_AVAILABLE); // object-pointer member registered as CSGObject*; MS_NOT_AVAILABLE presumably excludes it from model selection -- TODO confirm
59  SG_ADD((CSGObject**)&m_surrogate_loss, "m_surrogate_loss", "Surrogate loss", MS_NOT_AVAILABLE);
60  SG_ADD(&m_verbose, "verbose", "Verbosity flag", MS_NOT_AVAILABLE); // plain bool member, registered by address
61  SG_ADD((CSGObject**)&m_helper, "helper", "Training helper", MS_NOT_AVAILABLE); // NOTE(review): registration order may matter for (de)serialization -- keep it stable
62  // Defaults: not verbose, no training helper (get_helper() reports an error while m_helper is NULL).
63  m_verbose = false;
64  m_helper = NULL; // NOTE(review): plain reset without SG_UNREF -- only safe because this runs from constructors, before any helper is held
65 }
66 
68 {
70  REQUIRE(m_model != NULL, "please call set_model() before set_labels()\n");
72 }
73 
75 {
77 }
78 
80 {
81  return m_model->get_features();
82 }
83 
85 {
86  SG_REF(loss);
88  m_surrogate_loss = loss;
89 }
90 
92 {
94  return m_surrogate_loss;
95 }
96 
98 {
99  int32_t dim = m_model->get_dim();
100 
101  int32_t from=0, to=0;
102  CFeatures* features = get_features();
103  if (info)
104  {
105  from = info->m_from;
106  to = (info->m_N == 0) ? features->get_num_vectors() : from+info->m_N;
107  }
108  else
109  {
110  from = 0;
111  to = features->get_num_vectors();
112  }
113  SG_UNREF(features);
114 
115  float64_t R = 0.0;
116  for (int32_t i=0; i<dim; i++)
117  subgrad[i] = 0;
118 
119  for (int32_t i=from; i<to; i++)
120  {
121  CResultSet* result = m_model->argmax(SGVector<float64_t>(W,dim,false), i, true);
122  SGVector<float64_t> psi_pred = result->psi_pred;
123  SGVector<float64_t> psi_truth = result->psi_truth;
124  SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, 1.0, psi_pred.vector, dim);
125  SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, -1.0, psi_truth.vector, dim);
126  R += result->score;
127  SG_UNREF(result);
128  }
129 
130  return R;
131 }
132 
134 {
135  SG_ERROR("%s::risk_nslack_slack_rescale() has not been implemented!\n", get_name());
136  return 0.0;
137 }
138 
140 {
141  SG_ERROR("%s::risk_1slack_margin_rescale() has not been implemented!\n", get_name());
142  return 0.0;
143 }
144 
146 {
147  SG_ERROR("%s::risk_1slack_slack_rescale() has not been implemented!\n", get_name());
148  return 0.0;
149 }
150 
152 {
153  SG_ERROR("%s::risk_customized_formulation() has not been implemented!\n", get_name());
154  return 0.0;
155 }
156 
158  TMultipleCPinfo* info, EStructRiskType rtype)
159 {
160  float64_t ret = 0.0;
161  switch(rtype)
162  {
164  ret = risk_nslack_margin_rescale(subgrad, W, info);
165  break;
167  ret = risk_nslack_slack_rescale(subgrad, W, info);
168  break;
170  ret = risk_1slack_margin_rescale(subgrad, W, info);
171  break;
173  ret = risk_1slack_slack_rescale(subgrad, W, info);
174  break;
175  case CUSTOMIZED_RISK:
176  ret = risk_customized_formulation(subgrad, W, info);
177  break;
178  default:
179  SG_ERROR("%s::risk(): cannot recognize the risk type!\n", get_name());
180  ret = -1;
181  break;
182  }
183  return ret;
184 }
185 
187 { // Returns m_helper after an SG_REF; presumably the caller must SG_UNREF the returned pointer -- confirm against header docs.
188  if (m_helper == NULL) // no helper yet; per the message below, one is expected to exist only when verbose was set before training
189  {
190  SG_ERROR("%s::get_helper(): no helper has been created! "
191  "Please set verbose before training!\n", get_name()); // trailing space above keeps the concatenated literals from reading "created!Please"
192  }
193 
194  SG_REF(m_helper);
195  return m_helper;
196 }
197 
199 {
200  m_verbose = verbose;
201 }
202 
204 {
205  return m_verbose;
206 }

SHOGUN Machine Learning Toolbox - Documentation