SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
C45ClassifierTree.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
35 
36 using namespace shogun;
37 
39 
42 {
43  init();
44 }
45 
47 {
48 }
49 
51 {
52  REQUIRE(data, "Data required for classification in apply_multiclass\n")
53 
54  // apply multiclass starting from root
55  node_t* current=get_root();
56  CMulticlassLabels* ret=apply_multiclass_from_current_node(dynamic_cast<CDenseFeatures<float64_t>*>(data), current, true);
57 
58  SG_UNREF(current);
59  return ret;
60 }
61 
63 {
64  node_t* current=get_root();
65  prune_tree_from_current_node(validation_data,validation_labels,current,epsilon);
66 
67  SG_UNREF(current);
68 }
69 
71 {
72  return m_certainty;
73 }
74 
76 {
77  m_weights=w;
78  m_weights_set=true;
79 }
80 
82 {
83  return m_weights;
84 }
85 
87 {
88  m_weights=SGVector<float64_t>();
89  m_weights_set=false;
90 }
91 
93 {
94  m_nominal=ft;
95  m_types_set=true;
96 }
97 
99 {
100  return m_nominal;
101 }
102 
104 {
105  m_nominal=SGVector<bool>();
106  m_types_set=false;
107 }
108 
110 {
111  REQUIRE(data,"Data required for training\n")
112  REQUIRE(data->get_feature_class()==C_DENSE,"Dense data required for training\n")
113 
114  int32_t num_features=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_num_features();
115  int32_t num_vectors=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_num_vectors();
116 
117  if (m_weights_set)
118  {
119  REQUIRE(m_weights.vlen==num_vectors,"Length of weights vector (currently %d) should be same as"
120  " number of vectors in data (presently %d)",m_weights.vlen,num_vectors)
121  }
122  else
123  {
124  // all weights are equal to 1
125  m_weights=SGVector<float64_t>(num_vectors);
126  m_weights.fill_vector(m_weights.vector,m_weights.vlen,1.0);
127  }
128 
129  if (m_types_set)
130  {
131  REQUIRE(m_nominal.vlen==num_features,"Length of m_nominal vector (currently %d) should "
132  "be same as number of features in data (presently %d)",m_nominal.vlen,num_features)
133  }
134  else
135  {
136  SG_WARNING("Feature types are not specified. All features are considered as continuous in training")
137  m_nominal=SGVector<bool>(num_features);
138  m_nominal.fill_vector(m_nominal.vector,m_nominal.vlen,false);
139  }
140 
141  SGVector<int32_t> feature_ids(num_features);
142  feature_ids.range_fill();
143 
144  set_root(C45train(data, m_weights, dynamic_cast<CMulticlassLabels*>(m_labels), feature_ids, 0));
145 
146  return true;
147 }
148 
/**
 * Recursively builds a C4.5 (sub)tree for the given samples and returns its root node.
 *
 * The new node stores the label picked by the weighted-mode scan below (class_label),
 * the sum of all sample weights (total_weight) and the weight of the samples whose label
 * differs from class_label (weight_minus). The node is returned without children when
 * only one unique label is present, when feature_id_vector is empty, or when all feature
 * vectors are identical. Otherwise the attribute with the largest informational gain is
 * chosen (a continuous attribute is binarised about a candidate threshold first), one
 * child is trained per unique non-MISSING value of that attribute, and samples whose
 * value is MISSING are handed to every child with a rectified weight.
 *
 * @param data training features; cast to CDenseFeatures<float64_t> without a NULL check
 * @param weights per-sample weights
 * @param class_labels labels of the samples
 * @param feature_id_vector ids (indices into m_nominal) of the features still present in data
 * @param level recursion depth; within this function it is only forwarded as level+1
 * @return newly allocated node holding the trained subtree
 *
 * NOTE(review): this listing lacks the statements that declare `max_value` and `new_data`
 * (the embedded listing numbers skip 245 and 397) - confirm against the real source file.
 */
149 CTreeMachineNode<C45TreeNodeData>* CC45ClassifierTree::C45train(CFeatures* data, SGVector<float64_t> weights,
150  CMulticlassLabels* class_labels, SGVector<int32_t> feature_id_vector, int32_t level)
151 {
152  REQUIRE(data,"data matrix cannot be NULL\n");
153  REQUIRE(class_labels,"class labels cannot be NULL\n");
154  node_t* node=new node_t();
155  CDenseFeatures<float64_t>* feats=dynamic_cast<CDenseFeatures<float64_t>*>(data);
156  int32_t num_vecs=feats->get_num_vectors();
157 
158  // set class_label for the node as the mode of occurring multiclass labels
159  SGVector<float64_t> labels=class_labels->get_labels_copy();
160  labels.qsort();
     // NOTE(review): only the label copy is sorted; weights[i] below is still read in the
     // original sample order - confirm this pairing of sorted labels and weights is intended.
161 
     // NOTE(review): the accumulators are int32_t although weights holds float64_t values,
     // so fractional weights are truncated on assignment - confirm.
162  int32_t most_label=labels[0];
163  int32_t most_weight=weights[0];
164  int32_t weight=weights[0];
165 
     // scan runs of equal labels, remembering the run with the largest accumulated weight
166  for (int32_t i=1; i<labels.vlen; i++)
167  {
168  if (labels[i]==labels[i-1])
169  {
170  weight+=weights[i];
171  }
172  else if (weight>most_weight)
173  {
174  most_weight=weight;
175  most_label=labels[i-1];
176  weight=weights[i];
177  }
178  else
179  {
180  weight=weights[i];
181  }
182  }
183 
     // the last run is not closed inside the loop, so it is compared here
184  if (weight>most_weight)
185  {
186  most_weight=weight;
187  most_label=labels[labels.vlen-1];
188  }
189 
190  node->data.class_label=most_label;
191  node->data.total_weight=weights.sum(weights.vector,weights.vlen);
192  node->data.weight_minus=0.0;
     // weight_minus: total weight of the samples that do not carry the chosen label
193  for (int32_t i=0;i<labels.vlen;i++)
194  {
195  if (class_labels->get_label(i)!=most_label)
196  node->data.weight_minus+=weights[i];
197  }
198 
199  // if all samples belong to the same class
200  if (class_labels->get_unique_labels().size()==1)
201  return node;
202 
203  // if no feature is left
204  if (feature_id_vector.vlen==0)
205  return node;
206 
207  // if all remaining attributes are identical
     // (each vector is compared with its predecessor; the first difference clears the flag)
208  bool flag=true;
209  for (int32_t i=1;i<num_vecs;i++)
210  {
211  for (int32_t j=0;j<feats->get_num_features();j++)
212  {
213  if (feats->get_feature_vector(i)[j]!=feats->get_feature_vector(i-1)[j])
214  {
215  flag=false;
216  break;
217  }
218  }
219 
220  if (!flag)
221  break;
222  }
223 
224  if (flag)
225  return node;
226 
227  // else get the feature with the highest informational gain. threshold is used for continuous features only.
228  float64_t max=0;
229  int32_t best_feature_index=-1;
230  float64_t threshold=0.;
     // NOTE(review): i runs over feats->get_num_features() but also indexes feature_id_vector;
     // presumably both always have the same length - confirm.
231  for (int32_t i=0; i<feats->get_num_features(); i++)
232  {
233  if (m_nominal[feature_id_vector[i]])
234  {
235  float64_t gain=informational_gain_attribute(i,feats,weights,class_labels);
     // nominal attributes win ties (>=); continuous candidates below need a strictly larger gain
236  if (gain>=max)
237  {
238  max=gain;
239  best_feature_index=i;
240  }
241  }
242  else
243  {
244  SGVector<float64_t> feature_values(num_vecs);
     // NOTE(review): the declaration of max_value (listing line 245) is absent from this listing.
     // copy the attribute column and track its largest non-MISSING value
246  for (int32_t k=0; k<num_vecs; k++)
247  {
248  feature_values[k]=(feats->get_feature_vector(k))[i];
249 
250  if (!CMath::fequals(feature_values[k],MISSING,0) && feature_values[k]>max_value)
251  max_value=feature_values[k];
252  }
253 
     // every non-MISSING value except the largest one is tried as threshold z
254  for (int32_t k=0;k<num_vecs;k++)
255  {
256  if (feature_values[k]!=max_value && !CMath::fequals(feature_values[k],MISSING,0))
257  {
258  // form temporary dense features to calculate gain (continuous->nominal conversion)
259  float64_t z=feature_values[k];
260  SGMatrix<float64_t> temp_feat_mat=SGMatrix<float64_t>(1,num_vecs);
261  for (int32_t l=0;l<num_vecs;l++)
262  {
263  if (CMath::fequals(feature_values[l],MISSING,0))
264  temp_feat_mat(0,l)=MISSING;
265  else if (feature_values[l]<=z)
266  temp_feat_mat(0,l)=0.;
267  else
268  temp_feat_mat(0,l)=1.;
269  }
270 
271  CDenseFeatures<float64_t>* temp_feats=new CDenseFeatures<float64_t>(temp_feat_mat);
272  float64_t gain=informational_gain_attribute(0,temp_feats,weights,class_labels);
273  if (gain>max)
274  {
275  threshold=z;
276  max=gain;
277  best_feature_index=i;
278  }
279 
280  SG_UNREF(temp_feats);
281  }
282  }
283  }
284  }
285 
286  // feature cache for data restoration if feature is continuous
287  SGVector<float64_t> feature_cache(num_vecs);
288 
     // NOTE(review): best_feature_index is still -1 at this point if no candidate was
     // accepted in the loop above; it is used as an index without a check - confirm.
289  // if continuous attribute - split feature values about threshold
290  if (!m_nominal[feature_id_vector[best_feature_index]])
291  {
292  // convert continuous feature to nominal. Store cache for restoration
293  for(int32_t p=0;p<num_vecs;p++)
294  {
295  feature_cache[p]=feats->get_feature_vector(p)[best_feature_index];
     // MISSING entries are cached but left unchanged in the data
296  if (CMath::fequals(feature_cache[p],MISSING,0))
297  continue;
298 
299  if (feature_cache[p]<=threshold)
300  feats->get_feature_vector(p)[best_feature_index]=0.;
301  else
302  feats->get_feature_vector(p)[best_feature_index]=1.;
303  }
304  }
305 
306  // get feature values for the best feature chosen - shorthand for the features values of the best feature chosen
307  SGVector<float64_t> best_feature_values(num_vecs);
308  for (int32_t i=0; i<num_vecs; i++)
309  best_feature_values[i]=(feats->get_feature_vector(i))[best_feature_index];
310 
311  // prepare vector of unique feature values excluding MISSING , also calculate total weight associated with missing attributes
312  int32_t num_missing=0;
313  float64_t weight_missing=0.;
314  for (int32_t j=0;j<num_vecs;j++)
315  {
316  if (CMath::fequals(best_feature_values[j],MISSING,0))
317  {
318  num_missing++;
319  weight_missing+=weights[j];
320  }
321  }
322 
323  SGVector<float64_t> best_features_unique(num_vecs-num_missing);
324  int32_t index=0;
325  for (int32_t j=0;j<num_vecs;j++)
326  {
327  if (!CMath::fequals(best_feature_values[j],MISSING,0))
328  best_features_unique[index++]=best_feature_values[j];
329  }
330 
     // only the first uniques_num entries of best_features_unique are read afterwards
331  int32_t uniques_num=best_features_unique.unique(best_features_unique.vector,best_features_unique.vlen);
332 
333  // create child node for each unique value
334  for (int32_t i=0; i<uniques_num; i++)
335  {
336  //compute the number of vectors with active attribute value
337  int32_t num_cols=0;
338  float64_t active_feature_value=best_features_unique[i];
339 
     // samples with a MISSING value are counted for every child
340  for (int32_t j=0; j<num_vecs; j++)
341  {
342  if (active_feature_value==best_feature_values[j] || CMath::fequals(best_feature_values[j],MISSING,0))
343  num_cols++;
344  }
345 
     // the child data drops the chosen attribute, hence one feature row less
346  SGMatrix<float64_t> mat(feats->get_num_features()-1, num_cols);
347  SGVector<float64_t> new_labels_vector(num_cols);
348  SGVector<float64_t> new_weights(num_cols);
349 
350  int32_t cnt=0;
351  //choose the samples that have the active feature value
352  for (int32_t j=0; j<num_vecs; j++)
353  {
354  SGVector<float64_t> sample=feats->get_feature_vector(j);
355  if (active_feature_value==sample[best_feature_index] || CMath::fequals(sample[best_feature_index],MISSING,0))
356  {
     // copy every feature except the chosen one into column cnt of mat
357  int32_t idx=-1;
358  for (int32_t k=0; k<sample.size(); k++)
359  {
360  if (k!=best_feature_index)
361  mat(++idx, cnt) = sample[k];
362  }
363 
364  new_labels_vector[cnt]=class_labels->get_labels()[j];
     // MISSING samples get weight 0 for now; it is rectified after this loop
365  if (!CMath::fequals(sample[best_feature_index],MISSING,0))
366  new_weights[cnt]=weights[j];
367  else
368  new_weights[cnt]=0.;
369 
370  cnt++;
371  }
372  }
373 
374  // rectify weights of data points with missing attributes (set zero previously)
     // rec_weight = weight of this child's known samples / weight of all known samples
375  float64_t numer=new_weights.sum(new_weights.vector,new_weights.vlen);
376  float64_t rec_weight=numer/(node->data.total_weight-weight_missing);
377  cnt=0;
     // walk the samples in the same order as above so cnt addresses the matching child column
378  for (int32_t j=0;j<num_vecs;j++)
379  {
380  if (CMath::fequals(best_feature_values[j],MISSING,0))
381  new_weights[cnt++]=rec_weight;
382  else if (best_feature_values[j]==active_feature_value)
383  cnt++;
384  }
385 
386  //remove the best_attribute from the remaining attributes index vector
387  SGVector<int32_t> new_feature_id_vector(feature_id_vector.vlen-1);
388  cnt=-1;
389  for (int32_t j=0;j<feature_id_vector.vlen;j++)
390  {
391  if (j!=best_feature_index)
392  new_feature_id_vector[++cnt]=feature_id_vector[j];
393  }
394 
395  // new data & label for child node
396  CMulticlassLabels* new_class_labels=new CMulticlassLabels(new_labels_vector);
     // NOTE(review): the declaration of new_data (listing line 397) is absent from this listing;
     // presumably it wraps `mat` in a features object - confirm.
398 
399  // recursion over child nodes
400  node_t* child=C45train(new_data,new_weights,new_class_labels,new_feature_id_vector,level+1);
     // attribute_id is (re)assigned on every iteration with the same value
401  node->data.attribute_id=feature_id_vector[best_feature_index];
     // nominal split: child is keyed by the attribute value; continuous split: by the threshold
402  if (m_nominal[feature_id_vector[best_feature_index]])
403  child->data.transit_if_feature_value=active_feature_value;
404  else
405  child->data.transit_if_feature_value=threshold;
406 
407  node->add_child(child);
408 
409  SG_UNREF(new_class_labels);
410  SG_UNREF(new_data);
411  }
412 
413  // if continuous attribute - restoration required
414  if (!m_nominal[feature_id_vector[best_feature_index]])
415  {
416  // restore data matrix
417  for(int32_t p=0;p<num_vecs;p++)
418  feats->get_feature_vector(p)[best_feature_index]=feature_cache[p];
419  }
420 
421  return node;
422 }
423 
/**
 * Prunes the subtree rooted at `current` using validation data.
 *
 * The children of `current` are pruned first, each on the subset of validation vectors
 * that is routed to it (equality with the child's transit value for nominal attributes,
 * comparison against the left child's transit value for continuous ones). Afterwards the
 * accuracy of the subtree's predictions is compared with the accuracy obtained by
 * predicting current->data.class_label for every vector; the children of `current` are
 * replaced by an empty array when unpruned_accuracy < pruned_accuracy + epsilon.
 *
 * @param feats validation features; subsets are added and removed during recursion
 * @param gnd_truth validation labels; receives the same subsets as feats
 * @param current root of the subtree to prune
 * @param epsilon tolerance added to the pruned accuracy in the comparison
 *
 * NOTE(review): this listing lacks the statement that declares `accuracy`
 * (the embedded listing numbers skip 542) - confirm against the real source file.
 */
424 void CC45ClassifierTree::prune_tree_from_current_node(CDenseFeatures<float64_t>* feats,
425  CMulticlassLabels* gnd_truth, node_t* current, float64_t epsilon)
426 {
427  // if leaf node then skip pruning
     // (a node whose attribute_id is -1 is treated as a leaf here)
428  if (current->data.attribute_id==-1)
429  return;
430 
     // feature_matrix.num_cols is used below as the number of validation vectors
431  SGMatrix<float64_t> feature_matrix=feats->get_feature_matrix();
432  CDynamicObjectArray* children=current->get_children();
433 
434  if (m_nominal[current->data.attribute_id])
435  {
436  for (int32_t i=0; i<children->get_num_elements(); i++)
437  {
438  // count number of feature vectors which transit into the child
439  int32_t count=0;
440  node_t* child=dynamic_cast<node_t*>(children->get_element(i));
441 
442  for (int32_t j=0; j<feature_matrix.num_cols; j++)
443  {
444  float64_t child_transit=child->data.transit_if_feature_value;
445 
446  if (child_transit==feature_matrix(current->data.attribute_id,j))
447  count++;
448  }
449 
     // NOTE(review): this path skips the SG_UNREF(child) at the end of the loop body;
     // looks like the reference obtained from get_element is leaked here - confirm.
450  if (count==0)
451  continue;
452 
453  // form new subset of features and labels
454  SGVector<index_t> subset=SGVector<index_t>(count);
455  int32_t k=0;
456 
457  for (int32_t j=0; j<feature_matrix.num_cols;j++)
458  {
459  float64_t child_transit=child->data.transit_if_feature_value;
460 
461  if (child_transit==feature_matrix(current->data.attribute_id,j))
462  {
463  subset[k]=(index_t) j;
464  k++;
465  }
466  }
467 
     // restrict validation data to the vectors routed to this child for the recursive call
468  feats->add_subset(subset);
469  gnd_truth->add_subset(subset);
470 
471  // prune the child subtree
472  prune_tree_from_current_node(feats,gnd_truth,child,epsilon);
473 
474  feats->remove_subset();
475  gnd_truth->remove_subset();
476 
477  SG_UNREF(child);
478  }
479  }
480  else
481  {
482  REQUIRE(children->get_num_elements()==2,"The chosen attribute in current node is continuous. Expected number of"
483  " children is 2 but current node has %d children.",children->get_num_elements())
484 
     // element 0 is used as the "<= threshold" child, element 1 as the other one
485  node_t* left_child=dynamic_cast<node_t*>(children->get_element(0));
486  node_t* right_child=dynamic_cast<node_t*>(children->get_element(1));
487 
     // first pass: count the vectors going left so both subsets can be sized
488  int32_t count_left=0;
489  for (int32_t k=0;k<feature_matrix.num_cols;k++)
490  {
491  if (feature_matrix(current->data.attribute_id,k)<=left_child->data.transit_if_feature_value)
492  count_left++;
493  }
494 
495  SGVector<int32_t> left_subset(count_left);
496  SGVector<int32_t> right_subset(feature_matrix.num_cols-count_left);
497  int32_t l=0;
498  int32_t r=0;
     // second pass: distribute the vector indices over the two subsets
499  for (int32_t k=0;k<feature_matrix.num_cols;k++)
500  {
501  if (feature_matrix(current->data.attribute_id,k)<=left_child->data.transit_if_feature_value)
502  left_subset[l++]=k;
503  else
504  right_subset[r++]=k;
505  }
506 
507  // count_left is 0 if entire validation data in current node moves to only right child
508  if (count_left>0)
509  {
510  feats->add_subset(left_subset);
511  gnd_truth->add_subset(left_subset);
512  // prune the left child subtree
513  prune_tree_from_current_node(feats,gnd_truth,left_child,epsilon);
514  feats->remove_subset();
515  gnd_truth->remove_subset();
516  }
517 
518  // count_left is equal to num_cols if entire validation data in current node moves only to left child
519  if (count_left<feature_matrix.num_cols)
520  {
521  feats->add_subset(right_subset);
522  gnd_truth->add_subset(right_subset);
523  // prune the right child subtree
524  prune_tree_from_current_node(feats,gnd_truth,right_child,epsilon);
525  feats->remove_subset();
526  gnd_truth->remove_subset();
527  }
528 
529  SG_UNREF(left_child);
530  SG_UNREF(right_child);
531  }
532 
533  SG_UNREF(children);
534 
     // predictions of the (already child-pruned) subtree versus a constant prediction of
     // current's class_label for every validation vector
535  CMulticlassLabels* predicted_unpruned=apply_multiclass_from_current_node(feats, current);
536  SGVector<float64_t> pruned_labels=SGVector<float64_t>(feature_matrix.num_cols);
537  for (int32_t i=0; i<feature_matrix.num_cols; i++)
538  pruned_labels[i]=current->data.class_label;
539 
540  CMulticlassLabels* predicted_pruned=new CMulticlassLabels(pruned_labels);
541 
     // NOTE(review): the declaration of accuracy (listing line 542) is absent from this listing.
543  float64_t unpruned_accuracy=accuracy->evaluate(predicted_unpruned, gnd_truth);
544  float64_t pruned_accuracy=accuracy->evaluate(predicted_pruned, gnd_truth);
545 
     // prune when the subtree is not better than the constant prediction by at least epsilon
546  if (unpruned_accuracy<pruned_accuracy+epsilon)
547  {
548  CDynamicObjectArray* null_children=new CDynamicObjectArray();
549  current->set_children(null_children);
550  SG_UNREF(null_children);
551  }
552 
553  SG_UNREF(accuracy);
554  SG_UNREF(predicted_pruned);
555  SG_UNREF(predicted_unpruned);
556 }
557 
558 float64_t CC45ClassifierTree::informational_gain_attribute(int32_t attr_no, CFeatures* data,
559  SGVector<float64_t> weights, CMulticlassLabels* class_labels)
560 {
561  REQUIRE(data,"Data required for information gain calculation\n")
562  REQUIRE(data->get_feature_class()==C_DENSE,
563  "Dense data required for information gain calculation\n")
564 
565  float64_t gain=0;
566  CDenseFeatures<float64_t>* feats=dynamic_cast<CDenseFeatures<float64_t>*>(data);
567  int32_t num_vecs=feats->get_num_vectors();
568  SGVector<float64_t> gain_attribute_values;
569  SGVector<float64_t> gain_weights=weights;
570  CMulticlassLabels* gain_labels=class_labels;
571 
572  int32_t num_missing=0;
573  for (int32_t i=0;i<num_vecs;i++)
574  {
575  if (CMath::fequals((feats->get_feature_vector(i))[attr_no],MISSING,0))
576  num_missing++;
577  }
578 
579  if (num_missing==0)
580  {
581  gain_attribute_values=SGVector<float64_t>(num_vecs);
582  for (int32_t i=0; i<num_vecs; i++)
583  gain_attribute_values[i]=(feats->get_feature_vector(i))[attr_no];
584  }
585  else
586  {
587  gain_attribute_values=SGVector<float64_t>(num_vecs-num_missing);
588  gain_weights=SGVector<float64_t>(num_vecs-num_missing);
589  SGVector<float64_t> label_vector(num_vecs-num_missing);
590  int32_t index=0;
591  for (int32_t i=0; i<num_vecs; i++)
592  {
593  if (!CMath::fequals((feats->get_feature_vector(i))[attr_no],MISSING,0))
594  {
595  gain_attribute_values[index]=(feats->get_feature_vector(i))[attr_no];
596  gain_weights[index]=weights[i];
597  label_vector[index++]=class_labels->get_label(i);
598  }
599  }
600 
601  num_vecs-=num_missing;
602  gain_labels=new CMulticlassLabels(label_vector);
603  }
604 
605  float64_t total_weight=gain_weights.sum(gain_weights.vector,gain_weights.vlen);
606 
607  SGVector<float64_t> attr_val_unique=gain_attribute_values.clone();
608  int32_t uniques_num=attr_val_unique.unique(attr_val_unique.vector,attr_val_unique.vlen);
609 
610  for (int32_t i=0; i<uniques_num; i++)
611  {
612  //calculate class entropy for the specific attribute_value
613  int32_t attr_count=0;
614  float64_t weight_count=0.;
615 
616  for (int32_t j=0; j<num_vecs; j++)
617  {
618  if (gain_attribute_values[j]==attr_val_unique[i])
619  {
620  weight_count+=gain_weights[j];
621  attr_count++;
622  }
623  }
624 
625  SGVector<float64_t> sub_class(attr_count);
626  SGVector<float64_t> sub_weights(attr_count);
627  int32_t count=0;
628 
629  for (int32_t j=0; j<num_vecs; j++)
630  {
631  if (gain_attribute_values[j]==attr_val_unique[i])
632  {
633  sub_weights[count]=gain_weights[j];
634  sub_class[count++]=gain_labels->get_label(j);
635  }
636  }
637 
638  CMulticlassLabels* sub_labels=new CMulticlassLabels(sub_class);
639  float64_t sub_entropy=entropy(sub_labels,sub_weights);
640  gain += sub_entropy*weight_count/total_weight;
641 
642  SG_UNREF(sub_labels);
643  }
644 
645  float64_t data_entropy=entropy(gain_labels,gain_weights);
646  gain = data_entropy-gain;
647 
648  if (num_missing!=0)
649  {
650  gain*=(num_vecs-0.f)/(num_vecs+num_missing-0.f);
651  SG_UNREF(gain_labels);
652  }
653 
654  return gain;
655 }
656 
657 float64_t CC45ClassifierTree::entropy(CMulticlassLabels* labels, SGVector<float64_t> weights)
658 {
659  SGVector<float64_t> log_ratios(labels->get_unique_labels().size());
660  float64_t total_weight=weights.sum(weights.vector,weights.vlen);
661 
662  for (int32_t i=0;i<labels->get_unique_labels().size();i++)
663  {
664  int32_t count=0;
665  float64_t weight_count=0.;
666  for (int32_t j=0;j<labels->get_num_labels();j++)
667  {
668  if (labels->get_unique_labels()[i]==labels->get_label(j))
669  {
670  weight_count+=weights[j];
671  count++;
672  }
673  }
674 
675  log_ratios[i]=weight_count/total_weight;
676  log_ratios[i]=CMath::log(log_ratios[i]);
677  }
678 
679  return CStatistics::entropy(log_ratios.vector,log_ratios.vlen);
680 }
681 
682 CMulticlassLabels* CC45ClassifierTree::apply_multiclass_from_current_node(CDenseFeatures<float64_t>* feats,
683  node_t* current, bool set_certainty)
684 {
685  REQUIRE(feats, "Features should not be NULL")
686  REQUIRE(current, "Current node should not be NULL")
687 
688  int32_t num_vecs=feats->get_num_vectors();
689  SGVector<float64_t> labels(num_vecs);
690  if (set_certainty)
691  m_certainty=SGVector<float64_t>(num_vecs);
692 
693  // classify vectors in feature matrix taking one at a time
694  for (int32_t i=0; i<num_vecs; i++)
695  {
696  // choose the current subtree as the entry point
697  SGVector<float64_t> sample=feats->get_feature_vector(i);
698  node_t* node=current;
699  SG_REF(node);
700  CDynamicObjectArray* children=node->get_children();
701 
702  // traverse the subtree until leaf node is reached
703  while (children->get_num_elements())
704  {
705  bool flag=false;
706  // if nominal attribute check for equality
707  if (m_nominal[node->data.attribute_id])
708  {
709  for (int32_t j=0; j<children->get_num_elements(); j++)
710  {
711  CSGObject* el=children->get_element(j);
712  node_t* child=NULL;
713  if (el!=NULL)
714  child=dynamic_cast<node_t*>(el);
715  else
716  SG_ERROR("%d element of children is NULL\n",j);
717 
718  if (child->data.transit_if_feature_value==sample[node->data.attribute_id])
719  {
720  flag=true;
721 
722  SG_UNREF(node);
723  node=child;
724 
725  SG_UNREF(children);
726  children=node->get_children();
727 
728  break;
729  }
730 
731  SG_UNREF(child);
732  }
733 
734  if (!flag)
735  break;
736  }
737  // if not nominal attribute check if greater or less than threshold
738  else
739  {
740  CSGObject* el=children->get_element(0);
741  node_t* left_child=NULL;
742  if (el!=NULL)
743  left_child=dynamic_cast<node_t*>(el);
744  else
745  SG_ERROR("left child is NULL\n")
746 
747  el=children->get_element(1);
748  node_t* right_child=NULL;
749  if (el!=NULL)
750  right_child=dynamic_cast<node_t*>(el);
751  else
752  SG_ERROR("left child is NULL\n")
753 
754  if (left_child->data.transit_if_feature_value>=sample[node->data.attribute_id])
755  {
756  SG_UNREF(node);
757  node=left_child;
758  SG_REF(left_child)
759 
760  SG_UNREF(children);
761  children=node->get_children();
762  }
763  else
764  {
765  SG_UNREF(node);
766  node=right_child;
767  SG_REF(right_child)
768 
769  SG_UNREF(children);
770  children=node->get_children();
771  }
772 
773  SG_UNREF(left_child);
774  SG_UNREF(right_child);
775  }
776  }
777 
778  // class_label of leaf node is the class to which chosen vector belongs
779  labels[i]=node->data.class_label;
780 
781  if (set_certainty)
782  m_certainty[i]=(node->data.total_weight-node->data.weight_minus)/node->data.total_weight;
783 
784  SG_UNREF(node);
785  SG_UNREF(children);
786  }
787 
788  CMulticlassLabels* ret=new CMulticlassLabels(labels);
789  return ret;
790 }
791 
792 void CC45ClassifierTree::init()
793 {
794  m_nominal=SGVector<bool>();
795  m_weights=SGVector<float64_t>();
796  m_certainty=SGVector<float64_t>();
797  m_types_set=false;
798  m_weights_set=false;
799 
800  SG_ADD(&m_nominal,"m_nominal", "feature types", MS_NOT_AVAILABLE);
801  SG_ADD(&m_weights,"m_weights", "weights", MS_NOT_AVAILABLE);
802  SG_ADD(&m_certainty,"m_certainty", "certainty", MS_NOT_AVAILABLE);
803  SG_ADD(&m_weights_set,"m_weights_set", "weights set", MS_NOT_AVAILABLE);
804  SG_ADD(&m_types_set,"m_types_set", "feature types set", MS_NOT_AVAILABLE);
805 }
806 

SHOGUN Machine Learning Toolbox - Documentation