SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
VwRegressor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights
3  * embodied in the content of this file are licensed under the BSD
4  * (revised) open source license.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * Written (W) 2011 Shashwat Lal Das
12  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
13  */
14 
#include <shogun/classifier/vw/VwRegressor.h>

#include <stdio.h>
#include <string.h>

#include <shogun/io/IOBuffer.h>
#include <shogun/loss/SquaredLoss.h>
18 
19 using namespace shogun;
20 
22  : CSGObject()
23 {
24  weight_vectors = NULL;
25  loss = new CSquaredLoss();
26  init(NULL);
27 }
28 
30  : CSGObject()
31 {
32  weight_vectors = NULL;
33  loss = new CSquaredLoss();
34  init(env_to_use);
35 }
36 
38 {
39  // TODO: the number of weight_vectors depends on num_threads
40  // this should be reimplemented using SGVector (for reference counting)
41  if (weight_vectors)
42  {
43  vw_size_t num_threads = 1;
44  for (vw_size_t i = 0; i < num_threads; i++)
45  {
46  SG_FREE(weight_vectors[i]);
47  }
48  }
49 
50  SG_FREE(weight_vectors);
51  SG_UNREF(loss);
52  SG_UNREF(env);
53 }
54 
55 void CVwRegressor::init(CVwEnvironment* env_to_use)
56 {
57  if (!env_to_use)
58  env_to_use = new CVwEnvironment();
59 
60  env = env_to_use;
61  SG_REF(env);
62 
63  // For each feature, there should be 'stride' number of
64  // elements in the weight vector
65  vw_size_t length = ((vw_size_t) 1) << env->num_bits;
66  env->thread_mask = (env->stride * (length >> env->thread_bits)) - 1;
67 
68  // Only one learning thread for now
69  vw_size_t num_threads = 1;
70  weight_vectors = SG_MALLOC(float32_t*, num_threads);
71 
72  for (vw_size_t i = 0; i < num_threads; i++)
73  {
74  weight_vectors[i] = SG_CALLOC(float32_t, env->stride * length / num_threads);
75 
76  if (env->random_weights)
77  {
78  for (vw_size_t j = 0; j < length/num_threads; j++)
79  weight_vectors[i][j] = CMath::random(-0.5, 0.5);
80  }
81 
82  if (env->initial_weight != 0.)
83  for (vw_size_t j = 0; j < env->stride*length/num_threads; j+=env->stride)
85 
86  if (env->adaptive)
87  for (vw_size_t j = 1; j < env->stride*length/num_threads; j+=env->stride)
88  weight_vectors[i][j] = 1;
89  }
90 }
91 
92 // TODO: remove this, as we have serialization FW
93 void CVwRegressor::dump_regressor(char* reg_name, bool as_text)
94 {
95  CIOBuffer io_temp;
96  int32_t f = io_temp.open_file(reg_name,'w');
97 
98  if (f < 0)
99  SG_SERROR("Can't open: %s for writing! Exiting.\n", reg_name)
100 
101  const char* vw_version = env->vw_version;
102  vw_size_t v_length = env->v_length;
103 
104  if (!as_text)
105  {
106  // Write version info
107  io_temp.write_file((char*)&v_length, sizeof(v_length));
108  io_temp.write_file(vw_version,v_length);
109 
110  // Write max and min labels
111  io_temp.write_file((char*)&env->min_label, sizeof(env->min_label));
112  io_temp.write_file((char*)&env->max_label, sizeof(env->max_label));
113 
114  // Write weight vector bits information
115  io_temp.write_file((char *)&env->num_bits, sizeof(env->num_bits));
116  io_temp.write_file((char *)&env->thread_bits, sizeof(env->thread_bits));
117 
118  // For paired namespaces forming quadratic features
119  int32_t len = env->pairs.get_num_elements();
120  io_temp.write_file((char *)&len, sizeof(len));
121 
122  for (int32_t k = 0; k < env->pairs.get_num_elements(); k++)
123  io_temp.write_file(env->pairs.get_element(k), 2);
124 
125  // ngram and skips information
126  io_temp.write_file((char*)&env->ngram, sizeof(env->ngram));
127  io_temp.write_file((char*)&env->skips, sizeof(env->skips));
128  }
129  else
130  {
131  // Write as human readable form
132  char buff[512];
133  int32_t len;
134 
135  len = sprintf(buff, "Version %s\n", vw_version);
136  io_temp.write_file(buff, len);
137  len = sprintf(buff, "Min label:%f max label:%f\n", env->min_label, env->max_label);
138  io_temp.write_file(buff, len);
139  len = sprintf(buff, "bits:%d thread_bits:%d\n", (int32_t)env->num_bits, (int32_t)env->thread_bits);
140  io_temp.write_file(buff, len);
141 
142  if (env->pairs.get_num_elements() > 0)
143  {
144  len = sprintf(buff, "\n");
145  io_temp.write_file(buff, len);
146  }
147 
148  len = sprintf(buff, "ngram:%d skips:%d\nindex:weight pairs:\n", (int32_t)env->ngram, (int32_t)env->skips);
149  io_temp.write_file(buff, len);
150  }
151 
152  uint32_t length = 1 << env->num_bits;
153  vw_size_t num_threads = env->num_threads();
154  vw_size_t stride = env->stride;
155 
156  // Write individual weights
157  for(uint32_t i = 0; i < length; i++)
158  {
159  float32_t v;
160  v = weight_vectors[i%num_threads][stride*(i/num_threads)];
161  if (v != 0.)
162  {
163  if (!as_text)
164  {
165  io_temp.write_file((char *)&i, sizeof (i));
166  io_temp.write_file((char *)&v, sizeof (v));
167  }
168  else
169  {
170  char buff[512];
171  int32_t len = sprintf(buff, "%d:%f\n", i, v);
172  io_temp.write_file(buff, len);
173  }
174  }
175  }
176 
177  io_temp.close_file();
178 }
179 
180 // TODO: remove this, as we have serialization FW
182 {
183  CIOBuffer source;
184  int32_t fd = source.open_file(file, 'r');
185 
186  if (fd < 0)
187  SG_SERROR("Unable to open file for loading regressor!\n")
188 
189  // Read version info
190  vw_size_t v_length;
191  source.read_file((char*)&v_length, sizeof(v_length));
192  char* t = SG_MALLOC(char, v_length);
193  source.read_file(t,v_length);
194  if (strcmp(t,env->vw_version) != 0)
195  {
196  SG_FREE(t);
197  SG_SERROR("Regressor source has an incompatible VW version!\n")
198  }
199  SG_FREE(t);
200 
201  // Read min and max label
202  source.read_file((char*)&env->min_label, sizeof(env->min_label));
203  source.read_file((char*)&env->max_label, sizeof(env->max_label));
204 
205  // Read num_bits, multiple sources are not supported
206  vw_size_t local_num_bits;
207  source.read_file((char *)&local_num_bits, sizeof(local_num_bits));
208 
209  if ((vw_size_t) env->num_bits != local_num_bits)
210  SG_SERROR("Wrong number of bits in regressor source!\n")
211 
212  env->num_bits = local_num_bits;
213 
214  vw_size_t local_thread_bits;
215  source.read_file((char*)&local_thread_bits, sizeof(local_thread_bits));
216 
217  env->thread_bits = local_thread_bits;
218 
219  int32_t len;
220  source.read_file((char *)&len, sizeof(len));
221 
222  // Read paired namespace information
223  DynArray<char*> local_pairs;
224  for (; len > 0; len--)
225  {
226  char pair[3];
227  source.read_file(pair, sizeof(char)*2);
228  pair[2]='\0';
229  local_pairs.push_back(pair);
230  }
231 
232  env->pairs = local_pairs;
233 
234  // Initialize the weight vector
235  if (weight_vectors)
236  SG_FREE(weight_vectors);
237  init(env);
238 
239  vw_size_t local_ngram;
240  source.read_file((char*)&local_ngram, sizeof(local_ngram));
241  vw_size_t local_skips;
242  source.read_file((char*)&local_skips, sizeof(local_skips));
243 
244  env->ngram = local_ngram;
245  env->skips = local_skips;
246 
247  // Read individual weights
248  vw_size_t stride = env->stride;
249  while (true)
250  {
251  uint32_t hash;
252  ssize_t hash_bytes = source.read_file((char *)&hash, sizeof(hash));
253  if (hash_bytes <= 0)
254  break;
255 
256  float32_t w = 0.;
257  ssize_t weight_bytes = source.read_file((char *)&w, sizeof(float32_t));
258  if (weight_bytes <= 0)
259  break;
260 
261  vw_size_t num_threads = env->num_threads();
262 
263  weight_vectors[hash % num_threads][(hash*stride)/num_threads]
264  = weight_vectors[hash % num_threads][(hash*stride)/num_threads] + w;
265  }
266  source.close_file();
267 }

SHOGUN Machine Learning Toolbox - Documentation