SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CustomDistance.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
11 #ifndef _CUSTOMDISTANCE_H___
12 #define _CUSTOMDISTANCE_H___
13 
14 #include <shogun/lib/config.h>
15 
17 #include <shogun/lib/common.h>
20 
21 namespace shogun
22 {
32 {
33  public:
36 
43 
47  CCustomDistance(const SGMatrix<float64_t> distance_matrix);
48 
60  const float64_t* dm, int32_t rows, int32_t cols);
61 
73  const float32_t* dm, int32_t rows, int32_t cols);
74 
75  virtual ~CCustomDistance();
76 
87  virtual bool dummy_init(int32_t rows, int32_t cols);
88 
95  virtual bool init(CFeatures* l, CFeatures* r);
96 
98  virtual void cleanup();
99 
105 
110  virtual EFeatureType get_feature_type() { return F_ANY; }
111 
116  virtual EFeatureClass get_feature_class() { return C_ANY; }
117 
122  virtual const char* get_name() const { return "CustomDistance"; }
123 
135  const float64_t* dm, int32_t len)
136  {
138  }
139 
151  const float32_t* dm, int32_t len)
152  {
154  }
155 
166  template <class T>
168  const T* dm, int64_t len)
169  {
170  ASSERT(dm)
171  ASSERT(len>0)
172 
173  int64_t cols = (int64_t) floor(-0.5 + CMath::sqrt(0.25+2*len));
174 
175  int64_t int32_max=2147483647;
176 
177  if (cols> int32_max)
178  SG_ERROR("Matrix larger than %d x %d\n", int32_max)
179 
180  if (cols*(cols+1)/2 != len)
181  {
182  SG_ERROR("dm should be a vector containing a lower triangle matrix, with len=cols*(cols+1)/2 elements\n")
183  return false;
184  }
185 
186  cleanup_custom();
187  SG_DEBUG("using custom distance of size %dx%d\n", cols,cols)
188 
189  dmatrix= SG_MALLOC(float32_t, len);
190 
191  upper_diagonal=true;
192  num_rows=cols;
193  num_cols=cols;
194 
195  for (int64_t i=0; i<len; i++)
196  dmatrix[i]=dm[i];
197 
199  return true;
200  }
201 
213  const float64_t* dm, int32_t rows, int32_t cols)
214  {
215  return set_triangle_distance_matrix_from_full_generic(dm, rows, cols);
216  }
217 
229  const float32_t* dm, int32_t rows, int32_t cols)
230  {
231  return set_triangle_distance_matrix_from_full_generic(dm, rows, cols);
232  }
233 
242  template <class T>
244  const T* dm, int32_t rows, int32_t cols)
245  {
246  ASSERT(rows==cols)
247 
248  cleanup_custom();
249  SG_DEBUG("using custom distance of size %dx%d\n", cols,cols)
250 
251  dmatrix= SG_MALLOC(float32_t, int64_t(cols)*(cols+1)/2);
252 
253  upper_diagonal=true;
254  num_rows=cols;
255  num_cols=cols;
256 
257  for (int64_t row=0; row<num_rows; row++)
258  {
259  for (int64_t col=row; col<num_cols; col++)
260  {
261  int64_t idx=row * num_cols - row*(row+1)/2 + col;
262  dmatrix[idx]= (float32_t) dm[col*num_rows+row];
263  }
264  }
265  dummy_init(rows, cols);
266  return true;
267  }
268 
279  const float64_t* dm, int32_t rows, int32_t cols)
280  {
281  return set_full_distance_matrix_from_full_generic(dm, rows, cols);
282  }
283 
294  const float32_t* dm, int32_t rows, int32_t cols)
295  {
296  return set_full_distance_matrix_from_full_generic(dm, rows, cols);
297  }
298 
306  template <class T>
307  bool set_full_distance_matrix_from_full_generic(const T* dm, int32_t rows, int32_t cols)
308  {
309  cleanup_custom();
310  SG_DEBUG("using custom distance of size %dx%d\n", rows,cols)
311 
312  dmatrix=SG_MALLOC(float32_t, rows*cols);
313 
314  upper_diagonal=false;
315  num_rows=rows;
316  num_cols=cols;
317 
318  for (int32_t row=0; row<num_rows; row++)
319  {
320  for (int32_t col=0; col<num_cols; col++)
321  {
322  dmatrix[row * num_cols + col]=dm[col*num_rows+row];
323  }
324  }
325 
326  dummy_init(rows, cols);
327  return true;
328  }
329 
334  virtual int32_t get_num_vec_lhs()
335  {
336  return num_rows;
337  }
338 
343  virtual int32_t get_num_vec_rhs()
344  {
345  return num_cols;
346  }
347 
352  virtual bool has_features()
353  {
354  return (num_rows>0) && (num_cols>0);
355  }
356 
357  protected:
364  virtual float64_t compute(int32_t row, int32_t col);
365 
366  private:
367  void init();
368 
370  void cleanup_custom();
371 
372  protected:
376  int32_t num_rows;
378  int32_t num_cols;
381 };
382 
383 }
384 #endif /* _CUSTOMDISTANCE_H___ */

SHOGUN Machine Learning Toolbox - Documentation