SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MulticlassMachine.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2011 Soeren Sonnenburg
8  * Written (W) 2012 Fernando José Iglesias García and Sergey Lisitsyn
9  * Written (W) 2013 Shell Hu and Heiko Strathmann
10  * Copyright (C) 2012 Sergey Lisitsyn, Fernando José Iglesias Garcia
11  */
12 
20 
21 using namespace shogun;
22 
24 : CBaseMulticlassMachine(), m_multiclass_strategy(new CMulticlassOneVsRestStrategy()),
25  m_machine(NULL)
26 {
    // Constructor body whose signature line is not part of this listing.
    // The initializer list above installs a newly allocated
    // CMulticlassOneVsRestStrategy as the strategy and leaves the base
    // machine unset (m_machine is NULL).
    // NOTE(review): the embedded numbering jumps from 26 to 28, so one
    // statement of the original body is missing here.
    // Registers m_multiclass_strategy and m_machine via SG_ADD.
28  register_parameters();
29 }
30 
32  CMulticlassStrategy *strategy,
33  CMachine* machine, CLabels* labs)
34 : CBaseMulticlassMachine(), m_multiclass_strategy(strategy)
35 {
    // Constructor taking an explicit strategy, base machine and labels.
    // NOTE(review): the first line of the signature is missing from this
    // listing (the embedded numbering starts at 32).
    //
    // Both `strategy` and `machine` are SG_REF'ed before being kept as
    // members. m_machine is assigned in the body rather than in the
    // initializer list.
36  SG_REF(strategy);
37  set_labels(labs);
38  SG_REF(machine);
39  m_machine = machine;
40  register_parameters();
41 
    // The strategy is only initialised when labels were actually supplied.
42  if (labs)
43  init_strategy();
44 }
45 
47 {
50 }
51 
53 {
    // NOTE(review): the signature and the statement preceding this `if`
    // are missing from this listing (embedded numbering jumps 53 -> 55);
    // `lab` is presumably the labels parameter — confirm against the header.
    // Re-initialise the multiclass strategy whenever non-NULL labels are given.
55  if (lab)
56  init_strategy();
57 }
58 
59 void CMulticlassMachine::register_parameters()
60 {
61  SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE);
62  SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE);
63 }
64 
66 {
    // Reads the number of classes from m_labels, which is cast (unchecked,
    // C-style) to CMulticlassLabels.
    // NOTE(review): the signature and the statement that consumes
    // num_classes are missing from this listing (embedded numbering jumps
    // 67 -> 69); presumably the value is handed to m_multiclass_strategy —
    // confirm against the full file.
67  int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
69 }
70 
72 {
    // Applies the i-th entry of m_machines as a binary machine and returns
    // the resulting CBinaryLabels.
    // NOTE(review): the signature is missing from this listing; `i` is
    // presumably the sub-machine index parameter.
    //
    // The element obtained from m_machines is released with SG_UNREF once
    // its outputs have been computed; the returned labels are not released
    // here.
73  CMachine *machine = (CMachine*)m_machines->get_element(i);
74  ASSERT(machine)
75  CBinaryLabels* output = machine->apply_binary();
76  SG_UNREF(machine);
77  return output;
78 }
79 
81 {
82  CMachine *machine = get_machine(i);
83  float64_t output = 0.0;
84  // dirty hack
85  if (dynamic_cast<CLinearMachine*>(machine))
86  output = ((CLinearMachine*)machine)->apply_one(num);
87  if (dynamic_cast<CKernelMachine*>(machine))
88  output = ((CKernelMachine*)machine)->apply_one(num);
89  SG_UNREF(machine);
90  return output;
91 }
92 
94 {
    // Applies every trained sub-machine, combines their per-vector outputs
    // through m_multiclass_strategy and returns the resulting
    // CMulticlassLabels (with per-vector confidences attached).
    // NOTE(review): the signature is missing from this listing; `data` is
    // presumably the optional features argument.
95  SG_DEBUG("entering %s::apply_multiclass(%s at %p)\n",
96  get_name(), data ? data->get_name() : "NULL", data);
97 
98  CMulticlassLabels* return_labels=NULL;
99 
    // NOTE(review): the statements belonging to both branches of this
    // if/else are missing from this listing (numbering 100 -> 102 -> 104).
100  if (data)
102  else
104 
105  if (is_ready())
106  {
107  /* num vectors depends on whether data is provided */
    // NOTE(review): the fallback operand of this ternary (line 109) is
    // missing from this listing.
108  int32_t num_vectors=data ? data->get_num_vectors() :
110 
    // An empty machine list is reported as an error.
111  int32_t num_machines=m_machines->get_num_elements();
112  if (num_machines <= 0)
113  SG_ERROR("num_machines = %d, did you train your machine?", num_machines)
114 
115  CMulticlassLabels* result=new CMulticlassLabels(num_vectors);
116 
117  // if outputs are prob, only one confidence for each class
118  int32_t num_classes=m_multiclass_strategy->get_num_classes();
    // NOTE(review): the declaration of `heuris` (line 119) is missing from
    // this listing; it is compared against PROB_HEURIS_NONE / OVA_SOFTMAX
    // below.
120 
    // Confidence storage holds one entry per class when a probability
    // heuristic is active, otherwise one entry per sub-machine.
121  if (heuris!=PROB_HEURIS_NONE)
122  result->allocate_confidences_for(num_classes);
123  else
124  result->allocate_confidences_for(num_machines);
125 
    // Raw array of per-machine binary outputs; each element is SG_UNREF'ed
    // and the array is SG_FREE'd at the end of this branch.
126  CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
    // Per-machine sigmoid parameters, only filled for OVA_SOFTMAX.
127  SGVector<float64_t> As(num_machines);
128  SGVector<float64_t> Bs(num_machines);
129 
130  for (int32_t i=0; i<num_machines; ++i)
131  {
132  outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
133 
    // OVA_SOFTMAX: fit a sigmoid to this machine's values and remember its
    // two parameters for the rescaling step further down.
134  if (heuris==OVA_SOFTMAX)
135  {
136  CStatistics::SigmoidParamters params = CStatistics::fit_sigmoid(outputs[i]->get_values());
137  As[i] = params.a;
138  Bs[i] = params.b;
139  }
140 
    // Any other probability heuristic: call scores_to_probabilities(0,0)
    // on this machine's outputs.
141  if (heuris!=PROB_HEURIS_NONE && heuris!=OVA_SOFTMAX)
142  outputs[i]->scores_to_probabilities(0,0);
143  }
144 
    // output_for_i collects one value per machine for the current vector;
    // r_output_for_i is what is passed to decide_label and stored as the
    // confidences (resized to num_classes when a heuristic is active).
145  SGVector<float64_t> output_for_i(num_machines);
146  SGVector<float64_t> r_output_for_i(num_machines);
147  if (heuris!=PROB_HEURIS_NONE)
148  r_output_for_i.resize_vector(num_classes);
149 
150  for (int32_t i=0; i<num_vectors; i++)
151  {
152  for (int32_t j=0; j<num_machines; j++)
153  output_for_i[j] = outputs[j]->get_value(i);
154 
155  if (heuris==PROB_HEURIS_NONE)
156  {
    // NOTE(review): whether this assignment copies the values or makes both
    // vectors share storage depends on SGVector's assignment semantics —
    // confirm.
157  r_output_for_i = output_for_i;
158  }
159  else
160  {
    // Rescale the raw outputs via the strategy; the softmax variant also
    // receives the fitted sigmoid parameters.
161  if (heuris==OVA_SOFTMAX)
162  m_multiclass_strategy->rescale_outputs(output_for_i,As,Bs);
163  else
164  m_multiclass_strategy->rescale_outputs(output_for_i);
165 
166  // only first num_classes are returned
167  for (int32_t r=0; r<num_classes; r++)
168  r_output_for_i[r] = output_for_i[r];
169 
170  SG_DEBUG("%s::apply_multiclass(): sum(r_output_for_i) = %f\n",
171  get_name(), SGVector<float64_t>::sum(r_output_for_i.vector,num_classes));
172  }
173 
174  // use rescaled outputs for label decision
175  result->set_label(i, m_multiclass_strategy->decide_label(r_output_for_i));
176  result->set_multiclass_confidences(i, r_output_for_i);
177  }
178 
    // Release every per-machine output, then the array that held them.
179  for (int32_t i=0; i < num_machines; ++i)
180  SG_UNREF(outputs[i]);
181 
182  SG_FREE(outputs);
183 
184  return_labels=result;
185  }
186  else
187  SG_ERROR("Not ready")
188 
189 
    // return_labels is still NULL here if the is_ready() branch did not run
    // and SG_ERROR returned control (NOTE(review): SG_ERROR's control flow
    // is not visible in this file).
190  SG_DEBUG("leaving %s::apply_multiclass(%s at %p)\n",
191  get_name(), data ? data->get_name() : "NULL", data);
192  return return_labels;
193 }
194 
196 {
    // Applies every trained sub-machine and lets the strategy pick
    // n_outputs labels per vector, returned as CMultilabelLabels.
    // NOTE(review): the signature is missing from this listing; `data` and
    // `n_outputs` are presumably its parameters.
197  CMultilabelLabels* return_labels=NULL;
198 
    // NOTE(review): the statements belonging to both branches of this
    // if/else are missing from this listing (numbering 199 -> 201 -> 203).
199  if (data)
201  else
203 
204  if (is_ready())
205  {
206  /* num vectors depends on whether data is provided */
    // NOTE(review): the fallback operand of this ternary (line 208) is
    // missing from this listing.
207  int32_t num_vectors=data ? data->get_num_vectors() :
209 
    // An empty machine list is an error, and callers may not ask for more
    // outputs than there are machines.
210  int32_t num_machines=m_machines->get_num_elements();
211  if (num_machines <= 0)
212  SG_ERROR("num_machines = %d, did you train your machine?", num_machines)
213  REQUIRE(n_outputs<=num_machines,"You request more outputs than machines available")
214 
215  CMultilabelLabels* result=new CMultilabelLabels(num_vectors, n_outputs);
    // Raw array of per-machine binary outputs; each element is SG_UNREF'ed
    // and the array is SG_FREE'd at the end of this branch.
216  CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
217 
218  for (int32_t i=0; i < num_machines; ++i)
219  outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
220 
    // For each vector, gather one value per machine and let the strategy
    // choose the n_outputs labels.
221  SGVector<float64_t> output_for_i(num_machines);
222  for (int32_t i=0; i<num_vectors; i++)
223  {
224  for (int32_t j=0; j<num_machines; j++)
225  output_for_i[j] = outputs[j]->get_value(i);
226 
227  result->set_label(i, m_multiclass_strategy->decide_label_multiple_output(output_for_i, n_outputs));
228  }
229 
230  for (int32_t i=0; i < num_machines; ++i)
231  SG_UNREF(outputs[i]);
232 
233  SG_FREE(outputs);
234 
235  return_labels=result;
236  }
237  else
238  SG_ERROR("Not ready")
239 
    // Still NULL if the is_ready() branch did not run and SG_ERROR returned
    // control (NOTE(review): SG_ERROR's control flow is not visible here).
240  return return_labels;
241 }
242 
244 {
    // Trains the base machine m_machine repeatedly on binary label vectors,
    // optionally restricted to a subset, and returns true.
    // NOTE(review): the signature and several statements are missing from
    // this listing (numbering gaps at 245, 250, 252, 257-258, 260, 268, 273
    // and 277), including the loop header, the definition of `subset` and
    // whatever stores each trained machine. The comments below cover only
    // the visible lines.
246 
    // Training needs either explicit data or a machine that is already
    // ready; the `else` branch statement is not visible here.
247  if ( !data && !is_ready() )
248  SG_ERROR("Please provide training data.\n")
249  else
251 
    // One CBinaryLabels object, sized by get_num_rhs_vectors(), is created
    // up front, referenced for the duration of training and attached to the
    // base machine.
253  CBinaryLabels* train_labels = new CBinaryLabels(get_num_rhs_vectors());
254  SG_REF(train_labels);
255  m_machine->set_labels(train_labels);
256 
259  {
    // A non-empty subset is applied to both the labels and the machine
    // before training.
261  if (subset.vlen)
262  {
263  train_labels->add_subset(subset);
264  add_machine_subset(subset);
265  }
266 
267  m_machine->train();
269 
    // Undo the label subset after training. NOTE(review): the matching
    // removal of the machine subset is presumably on the missing line 273.
270  if (subset.vlen)
271  {
272  train_labels->remove_subset();
274  }
275  }
276 
    // Drop the reference taken on train_labels above.
278  SG_UNREF(train_labels);
279 
280  return true;
281 }
282 
284 {
    // Classifies a single example: collects one output per sub-machine for
    // index vec_idx and returns the label chosen by the strategy.
    // NOTE(review): the signature and the declaration of `outputs` are
    // missing from this listing (numbering gaps at 285 and 287-288);
    // `vec_idx` is presumably the parameter.
286 
289 
    // One scalar output per sub-machine for this example.
290  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
291  outputs[i] = get_submachine_output(i, vec_idx);
292 
293  float64_t result = m_multiclass_strategy->decide_label(outputs);
294 
295  return result;
296 }

SHOGUN Machine Learning Toolbox - Documentation