SHOGUN  4.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
LinearARDKernel.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2015 Wu Lin
8  * Written (W) 2012 Jacob Walker
9  *
10  * Adapted from WeightedDegreeRBFKernel.cpp
11  */
12 
14 
15 #ifdef HAVE_LINALG_LIB
17 #endif
18 
19 using namespace shogun;
20 
// Body of a constructor whose signature (listing line 21) is not visible in this
// view — presumably the default constructor; confirm. All member setup is
// delegated to initialize().
22 {
 23  initialize();
24 }
25 
// NOTE(review): listing lines 26 and 28 (this definition's signature and body)
// are not visible in this view; presumably this is the destructor — confirm
// against the real source before relying on it.
27 {
29 }
30 
// Common member setup, called from the constructor bodies in this file.
// NOTE(review): listing lines 33-34 and 37 are missing from this view. Presumably
// 33-34 give m_ARD_type and m_weights their defaults and 37 completes the first
// SG_ADD call — confirm against the real source.
31 void CLinearARDKernel::initialize()
32 {
 // Fill every entry of the weight matrix with 1.0.
 35  m_weights.set_const(1.0);
 // Register m_weights (name "weights", MS_AVAILABLE) and m_ARD_type (name "type",
 // MS_NOT_AVAILABLE, stored through an int cast) as named parameters.
 36  SG_ADD(&m_weights, "weights", "Feature weights", MS_AVAILABLE,
 38  SG_ADD((int *)(&m_ARD_type), "type", "ARD kernel type", MS_NOT_AVAILABLE);
39 }
40 
41 #ifdef HAVE_LINALG_LIB
// Body of a definition whose signature (listing line 42) is not visible in this
// view; presumably a constructor — confirm. It only calls initialize().
43 {
 44  initialize();
45 }
46 
48  CDotFeatures* r, int32_t size) : CDotKernel(size)
49 {
 // Constructor taking right features `r` and a `size` forwarded to CDotKernel
 // (its first signature line, listing 47, is not visible here): set up the
 // members, then bind the features via init(l, r).
 // NOTE(review): the bool returned by init() is discarded here.
 50  initialize();
 51  init(l,r);
52 }
53 
54 bool CLinearARDKernel::init(CFeatures* l, CFeatures* r)
55 {
56  cleanup();
57  CDotKernel::init(l, r);
58  int32_t dim=((CDotFeatures*) l)->get_dim_feature_space();
59  if (m_ARD_type==KT_FULL)
60  {
61  REQUIRE(m_weights.num_cols==dim, "Dimension mismatch between features (%d) and weights (%d)\n",
62  dim, m_weights.num_cols);
63  }
64  else if (m_ARD_type==KT_DIAG)
65  {
66  REQUIRE(m_weights.num_rows==dim, "Dimension mismatch between features (%d) and weights (%d)\n",
67  dim, m_weights.num_rows);
68  }
69  return init_normalizer();
70 }
71 
72 
73 SGMatrix<float64_t> CLinearARDKernel::compute_right_product(SGVector<float64_t>right_vec,
74  float64_t & scalar_weight)
75 {
76  SGMatrix<float64_t> right;
77 
78  if (m_ARD_type==KT_SCALAR)
79  {
80  right=SGMatrix<float64_t>(right_vec.vector,right_vec.vlen,1,false);
81  scalar_weight*=m_weights[0];
82  }
83  else
84  {
85  SGMatrix<float64_t> rtmp(right_vec.vector,right_vec.vlen,1,false);
86 
87  if(m_ARD_type==KT_DIAG)
88  right=linalg::elementwise_product(m_weights, rtmp);
89  else if(m_ARD_type==KT_FULL)
90  right=linalg::matrix_product(m_weights, rtmp);
91  else
92  SG_ERROR("Unsupported ARD type\n");
93  }
94  return right;
95 }
96 
// Kernel value for two feature vectors: (weighted avec as a row) times
// (weighted bvec as a column), times a scalar factor.
//   KT_SCALAR: scalar_weight = w here and compute_right_product multiplies it
//              by w again, giving w^2 * (a . b);
//   KT_DIAG:   (w o a) . (w o b) = sum_i w_i^2 a_i b_i  (o = element-wise);
//   KT_FULL:   (W a) . (W b).
// Any other ARD type raises SG_ERROR.
// NOTE(review): listing line 99 (presumably the declaration of `left`) is
// missing from this view.
97 float64_t CLinearARDKernel::compute_helper(SGVector<float64_t> avec, SGVector<float64_t>bvec)
98 {
 100  SGMatrix<float64_t> left_transpose;
 101  float64_t scalar_weight=1.0;
 102  if (m_ARD_type==KT_SCALAR)
 103  {
 // Row (1 x vlen) over avec's buffer; the weight goes into the scalar factor.
 104  left=SGMatrix<float64_t>(avec.vector,1,avec.vlen,false);
 105  scalar_weight=m_weights[0];
 106  }
 107  else
 108  {
 // Weight avec as a column, then reinterpret the result as a row.
 109  SGMatrix<float64_t> ltmp(avec.vector,avec.vlen,1,false);
 110  if(m_ARD_type==KT_DIAG)
 111  left_transpose=linalg::elementwise_product(m_weights, ltmp);
 112  else if(m_ARD_type==KT_FULL)
 113  left_transpose=linalg::matrix_product(m_weights, ltmp);
 114  else
 115  SG_ERROR("Unsupported ARD type\n");
 // `left` wraps left_transpose's buffer, which stays alive until the end
 // of this function.
 116  left=SGMatrix<float64_t>(left_transpose.matrix,1,left_transpose.num_rows,false);
 117  }
 118  SGMatrix<float64_t> right=compute_right_product(bvec, scalar_weight);
 // (1 x n) * (n x 1): the single entry is the weighted dot product.
 119  SGMatrix<float64_t> res=linalg::matrix_product(left, right);
 120  return res[0]*scalar_weight;
 121 }
122 
123 float64_t CLinearARDKernel::compute(int32_t idx_a, int32_t idx_b)
124 {
125  REQUIRE(lhs, "Left features not set!\n");
126  REQUIRE(rhs, "Right features not set!\n");
127  SGVector<float64_t> avec=((CDotFeatures *)lhs)->get_computed_dot_feature_vector(idx_a);
128  SGVector<float64_t> bvec=((CDotFeatures *)rhs)->get_computed_dot_feature_vector(idx_b);
129 
130  return compute_helper(avec, bvec);
131 }
132 
// Derivative of the kernel value for the pair (avec, bvec) with respect to the
// weight entry selected by `index`, multiplied by `scale`.
// Per ARD type (matching the kernel value built in compute_helper):
//   KT_DIAG:   k = sum_i w_i^2 a_i b_i  ->  dk/dw_i  = 2 w_i a_i b_i
//   KT_SCALAR: k = w^2 (a . b)          ->  dk/dw    = 2 w (a . b)
//   KT_FULL:   k = (W a) . (W b)        ->  dk/dW_rc = (W_r . a) b_c + (W_r . b) a_c
// NOTE(review): listing line 134 (rest of the parameter list, presumably bvec,
// scale and index) and line 146 (presumably the declaration of `res`) are
// missing from this view.
133 float64_t CLinearARDKernel::compute_gradient_helper(SGVector<float64_t> avec,
135 {
 136  float64_t result;
137 
 138  if(m_ARD_type==KT_DIAG)
 139  {
 140  result=2.0*avec[index]*bvec[index]*m_weights[index];
 141  }
 142  else
 143  {
 // avec as a (1 x vlen) row and bvec as a (vlen x 1) column over their buffers.
 144  SGMatrix<float64_t> left(avec.vector,1,avec.vlen,false);
 145  SGMatrix<float64_t> right(bvec.vector,bvec.vlen,1,false);
147 
 148  if (m_ARD_type==KT_SCALAR)
 149  {
 // `index` is not used for the scalar case.
 150  res=linalg::matrix_product(left, right);
 151  result=2.0*res[0]*m_weights[0];
 152  }
 153  else if(m_ARD_type==KT_FULL)
 154  {
 155  int32_t row_index=index%m_weights.num_rows;
 156  int32_t col_index=index/m_weights.num_rows;
 157  //index is a linearized index of m_weights (column-major)
 158  //m_weights is a d-by-p matrix, where p is #dimension of features
 159  SGVector<float64_t> row_vec=m_weights.get_row_vector(row_index);
 160  SGMatrix<float64_t> row_vec_r(row_vec.vector,row_vec.vlen,1,false);
161 
 // First term: (a . W_r) * b_c
 162  res=linalg::matrix_product(left, row_vec_r);
 163  result=res[0]*bvec[col_index];
164 
 // Second term: (W_r . b) * a_c
 165  SGMatrix<float64_t> row_vec_l(row_vec.vector,1,row_vec.vlen,false);
 166  res=linalg::matrix_product(row_vec_l, right);
 167  result+=res[0]*avec[col_index];
168 
 169  }
 170  else
 171  {
 // NOTE(review): `result` is left unset on this path, so this relies on
 // SG_ERROR not returning (presumably it throws) — confirm.
 172  SG_ERROR("Unsupported ARD type\n");
 173  }
174 
 175  }
 176  return result*scale;
 177 }
178 
179 
181  const TParameter* param, index_t index)
182 {
 // Gradient of the kernel matrix with respect to one entry of the weights.
 // Returns a num_lhs-by-num_rhs matrix whose (j,k) entry is
 // compute_gradient_helper(lhs vector j, rhs vector k, 1.0, index). Only the
 // parameter named "weights" is supported; anything else hits SG_ERROR.
 // NOTE(review): listing line 180 (the start of this signature) is not
 // visible in this view.
 183  REQUIRE(lhs, "Left features not set!\n");
 184  REQUIRE(rhs, "Right features not set!\n");
185 
 // row_index/col_index are only assigned (and checked) for KT_FULL. For
 // KT_SCALAR, `index` is not validated because the scalar gradient ignores it.
 186  int32_t row_index, col_index;
 187  if (m_ARD_type!=KT_SCALAR)
 188  {
 189  REQUIRE(index>=0, "Index (%d) must be non-negative\n",index);
 190  if (m_ARD_type==KT_DIAG)
 191  {
 192  REQUIRE(index<m_weights.num_rows, "Index (%d) must be within #dimension of weights (%d)\n",
 193  index, m_weights.num_rows);
 194  }
 195  else if(m_ARD_type==KT_FULL)
 196  {
 // Decompose the column-major linear index into (row, column).
 197  row_index=index%m_weights.num_rows;
 198  col_index=index/m_weights.num_rows;
 // NOTE(review): for a non-negative index this row check cannot fail,
 // because row_index is a remainder modulo num_rows.
 199  REQUIRE(row_index<m_weights.num_rows,
 200  "Row index (%d) must be within #row of weights (%d)\n",
 201  row_index, m_weights.num_rows);
 202  REQUIRE(col_index<m_weights.num_cols,
 203  "Column index (%d) must be within #column of weights (%d)\n",
 204  col_index, m_weights.num_cols);
 205  }
 206  }
 // NOTE(review): `param` is dereferenced without a null check.
 207  if (!strcmp(param->m_name, "weights"))
 208  {
 209  SGMatrix<float64_t> derivative(num_lhs, num_rhs);
210 
 211  for (index_t j=0; j<num_lhs; j++)
 212  {
 213  SGVector<float64_t> avec=((CDotFeatures *)lhs)->get_computed_dot_feature_vector(j);
 214  for (index_t k=0; k<num_rhs; k++)
 215  {
 // NOTE(review): each rhs vector is re-fetched once per lhs row
 // (num_lhs * num_rhs fetches in total).
 216  SGVector<float64_t> bvec=((CDotFeatures *)rhs)->get_computed_dot_feature_vector(k);
 217  derivative(j,k)=compute_gradient_helper(avec, bvec, 1.0, index);
 218  }
 219  }
 220  return derivative;
 221  }
 222  else
 223  {
 224  SG_ERROR("Can't compute derivative wrt %s parameter\n", param->m_name);
 225  return SGMatrix<float64_t>();
 226  }
 227 }
228 
// Getter for the feature weights.
// NOTE(review): the body (listing line 231) is missing from this view;
// presumably it returns m_weights or a copy of it — confirm.
229 SGMatrix<float64_t> CLinearARDKernel::get_weights()
230 {
232 }
233 
// Store a new weight matrix; its shape selects which branch below is taken.
// Empty matrices (zero rows or zero columns) are rejected with REQUIRE.
// NOTE(review): the branch bodies (listing lines 241, 247 and 251) are missing
// from this view. Presumably they set m_ARD_type to KT_FULL (more than one
// column), KT_SCALAR (1-by-1) and KT_DIAG (n-by-1, n > 1). That would match
// init(), which checks num_cols for KT_FULL and num_rows for KT_DIAG — confirm.
234 void CLinearARDKernel::set_weights(SGMatrix<float64_t> weights)
235 {
 236  REQUIRE(weights.num_cols>0 && weights.num_rows>0,
 237  "Weight Matrix (%d-by-%d) must not be empty\n",
 238  weights.num_rows, weights.num_cols);
 // More than one column: a full weight matrix.
 239  if (weights.num_cols>1)
 240  {
 242  }
 243  else
 244  {
 // Single column: 1-by-1 versus n-by-1.
 245  if (weights.num_rows==1)
 246  {
 248  }
 249  else
 250  {
 252  }
 253  }
 254  m_weights=weights;
 255 }
256 
257 void CLinearARDKernel::set_scalar_weights(float64_t weight)
258 {
259  SGMatrix<float64_t> weights(1,1);
260  weights(0,0)=weight;
261  set_weights(weights);
262 }
263 
264 void CLinearARDKernel::set_vector_weights(SGVector<float64_t> weights)
265 {
266  SGMatrix<float64_t> weights_mat(weights.vlen,1);
267  std::copy(weights.vector, weights.vector+weights.vlen, weights_mat.matrix);
268  set_weights(weights_mat);
269 }
270 
271 void CLinearARDKernel::set_matrix_weights(SGMatrix<float64_t> weights)
272 {
273  set_weights(weights);
274 }
275 #endif //HAVE_LINALG_LIB

SHOGUN Machine Learning Toolbox - Documentation