SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MultitaskLogisticRegression.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Copyright (C) 2012 Sergey Lisitsyn
8  */
9 
13 #include <vector>
14 
15 namespace shogun
16 {
17 
20 {
21  initialize_parameters();
22  register_parameters();
23 }
24 
26  float64_t z, CDotFeatures* train_features,
27  CBinaryLabels* train_labels, CTaskRelation* task_relation) :
28  CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation)
29 {
30  initialize_parameters();
31  register_parameters();
32  set_z(z);
33 }
34 
36 {
37 }
38 
39 void CMultitaskLogisticRegression::register_parameters() // Registers each solver setting with SG_ADD (member, name, description, model-selection flag); only z and q are MS_AVAILABLE — presumably the ones model selection may tune.
40 {
41  SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE); // passed to slep_solver as the regularization argument
42  SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE); // copied into slep_options::q before solving
43  SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE); // copied into slep_options::termination; its setter asserts 0..4
44  SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE); // copied into slep_options::regularization; its setter asserts 0 or 1
45  SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE); // copied into slep_options::tolerance; its setter asserts > 0
46  SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE); // copied into slep_options::max_iter; its setter asserts >= 0
47 }
48 
49 void CMultitaskLogisticRegression::initialize_parameters() // Sets default solver settings; both constructors call this before register_parameters().
50 {
51  set_z(0.0); // default z; the constructor taking z overrides it afterwards via set_z(z)
52  set_q(2.0);
53  set_termination(0); // NOTE(review): listing line 54 is missing from this rendering — presumably the m_regularization default (set_regularization); confirm against the real source
55  set_tolerance(1e-3); // satisfies the tolerance > 0 assertion in its setter
56  set_max_iter(1000);
57 }
58 
60 {
61  if (data && (CDotFeatures*)data)
62  set_features((CDotFeatures*)data);
63 
66 
68  for (int32_t i=0; i<y.vlen; i++)
69  y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
70 
71  slep_options options = slep_options::default_options();
72  options.n_tasks = m_task_relation->get_num_tasks();
73  options.tasks_indices = m_task_relation->get_tasks_indices();
74  options.q = m_q;
75  options.regularization = m_regularization;
76  options.termination = m_termination;
77  options.tolerance = m_tolerance;
78  options.max_iter = m_max_iter;
79 
80  ETaskRelationType relation_type = m_task_relation->get_relation_type();
81  switch (relation_type)
82  {
83  case TASK_GROUP:
84  {
85  //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
86  options.mode = MULTITASK_GROUP;
87  options.loss = LOGISTIC;
88  slep_result_t result = slep_solver(features, y.vector, m_z, options);
89  m_tasks_w = result.w;
90  m_tasks_c = result.c;
91  }
92  break;
93  case TASK_TREE:
94  {
95  CTaskTree* task_tree = (CTaskTree*)m_task_relation;
96  SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
97  options.ind_t = ind_t.vector;
98  options.n_nodes = ind_t.vlen / 3;
99  options.mode = MULTITASK_TREE;
100  options.loss = LOGISTIC;
101  slep_result_t result = slep_solver(features, y.vector, m_z, options);
102  m_tasks_w = result.w;
103  m_tasks_c = result.c;
104  }
105  break;
106  default:
107  SG_ERROR("Not supported task relation type\n")
108  }
109  SG_FREE(options.tasks_indices);
110 
111  return true;
112 }
113 
115 {
118 
120  for (int32_t i=0; i<y.vlen; i++)
121  y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
122 
123  slep_options options = slep_options::default_options();
124  options.n_tasks = m_task_relation->get_num_tasks();
125  options.tasks_indices = tasks;
126  options.q = m_q;
127  options.regularization = m_regularization;
128  options.termination = m_termination;
129  options.tolerance = m_tolerance;
130  options.max_iter = m_max_iter;
131 
132  ETaskRelationType relation_type = m_task_relation->get_relation_type();
133  switch (relation_type)
134  {
135  case TASK_GROUP:
136  {
137  //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
138  options.mode = MULTITASK_GROUP;
139  options.loss = LOGISTIC;
140  slep_result_t result = slep_solver(features, y.vector, m_z, options);
141  m_tasks_w = result.w;
142  m_tasks_c = result.c;
143  }
144  break;
145  case TASK_TREE:
146  {
147  CTaskTree* task_tree = (CTaskTree*)m_task_relation;
148  SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
149  options.ind_t = ind_t.vector;
150  options.n_nodes = ind_t.vlen / 3;
151  options.mode = MULTITASK_TREE;
152  options.loss = LOGISTIC;
153  slep_result_t result = slep_solver(features, y.vector, m_z, options);
154  m_tasks_w = result.w;
155  m_tasks_c = result.c;
156  }
157  break;
158  default:
159  SG_ERROR("Not supported task relation type\n")
160  }
161  return true;
162 }
163 
165 {
167  //float64_t ep = CMath::exp(-(dot + m_tasks_c[m_current_task]));
168  //return 2.0/(1.0+ep) - 1.0;
169  return dot + m_tasks_c[m_current_task];
170 }
171 
173 {
174  return m_max_iter;
175 }
177 {
178  return m_regularization;
179 }
181 {
182  return m_termination;
183 }
185 {
186  return m_tolerance;
187 }
189 {
190  return m_z;
191 }
193 {
194  return m_q;
195 }
196 
198 {
199  ASSERT(max_iter>=0)
200  m_max_iter = max_iter;
201 }
203 {
204  ASSERT(regularization==0 || regularization==1)
205  m_regularization = regularization;
206 }
208 {
209  ASSERT(termination>=0 && termination<=4)
210  m_termination = termination;
211 }
213 {
214  ASSERT(tolerance>0.0)
215  m_tolerance = tolerance;
216 }
218 {
219  m_z = z;
220 }
222 {
223  m_q = q;
224 }
225 
226 }

SHOGUN Machine Learning Toolbox - Documentation