SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
VwRegressor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights
3  * embodied in the content of this file are licensed under the BSD
4  * (revised) open source license.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * Written (W) 2011 Shashwat Lal Das
12  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
13  */
14 
17 #include <shogun/io/IOBuffer.h>
18 
19 using namespace shogun;
20 
// NOTE(review): the signature line of this constructor is missing from this
// listing; only the initializer list and body are visible.
// Initializes the base CSGObject, starts with no weight vectors, installs a
// freshly allocated CSquaredLoss as the loss, and calls init(NULL); init()
// creates a new CVwEnvironment when it is handed a null pointer.
22  : CSGObject()
23 {
24  weight_vectors = NULL;
25  loss = new CSquaredLoss();
26  init(NULL);
27 }
28 
// NOTE(review): the signature line of this constructor is missing from this
// listing; the body refers to a parameter named env_to_use.
// Initializes the base CSGObject, starts with no weight vectors, installs a
// freshly allocated CSquaredLoss as the loss, and forwards env_to_use to
// init(), which stores it as the environment and allocates the weights.
30  : CSGObject()
31 {
32  weight_vectors = NULL;
33  loss = new CSquaredLoss();
34  init(env_to_use);
35 }
36 
// NOTE(review): the signature line is missing from this listing; the body
// releases what the constructors/init() acquire, so this is presumably the
// destructor - confirm against the header.
// Frees the weight_vectors pointer array and drops the references held on
// the loss and the environment.
38 {
39  // TODO: the number of weight_vectors depends on num_threads
40  // this should be reimplemented using SGVector (for reference counting)
// NOTE(review): only the outer pointer array is freed here; the per-thread
// buffers that init() allocates with SG_CALLOC are not released in this
// block - looks like a leak, confirm.
41  SG_FREE(weight_vectors);
42  SG_UNREF(loss);
43  SG_UNREF(env);
44 }
45 
46 void CVwRegressor::init(CVwEnvironment* env_to_use)
47 {
48  if (!env_to_use)
49  env_to_use = new CVwEnvironment();
50 
51  env = env_to_use;
52  SG_REF(env);
53 
54  // For each feature, there should be 'stride' number of
55  // elements in the weight vector
56  vw_size_t length = ((vw_size_t) 1) << env->num_bits;
57  env->thread_mask = (env->stride * (length >> env->thread_bits)) - 1;
58 
59  // Only one learning thread for now
60  vw_size_t num_threads = 1;
61  weight_vectors = SG_MALLOC(float32_t*, num_threads);
62 
63  for (vw_size_t i = 0; i < num_threads; i++)
64  {
65  weight_vectors[i] = SG_CALLOC(float32_t, env->stride * length / num_threads);
66 
67  if (env->random_weights)
68  {
69  for (vw_size_t j = 0; j < length/num_threads; j++)
70  weight_vectors[i][j] = CMath::random(-0.5, 0.5);
71  }
72 
73  if (env->initial_weight != 0.)
74  for (vw_size_t j = 0; j < env->stride*length/num_threads; j+=env->stride)
76 
77  if (env->adaptive)
78  for (vw_size_t j = 1; j < env->stride*length/num_threads; j+=env->stride)
79  weight_vectors[i][j] = 1;
80  }
81 }
82 
83 // TODO: remove this, as we have serialization FW
84 void CVwRegressor::dump_regressor(char* reg_name, bool as_text)
85 {
86  CIOBuffer io_temp;
87  int32_t f = io_temp.open_file(reg_name,'w');
88 
89  if (f < 0)
90  SG_SERROR("Can't open: %s for writing! Exiting.\n", reg_name)
91 
92  const char* vw_version = env->vw_version;
93  vw_size_t v_length = env->v_length;
94 
95  if (!as_text)
96  {
97  // Write version info
98  io_temp.write_file((char*)&v_length, sizeof(v_length));
99  io_temp.write_file(vw_version,v_length);
100 
101  // Write max and min labels
102  io_temp.write_file((char*)&env->min_label, sizeof(env->min_label));
103  io_temp.write_file((char*)&env->max_label, sizeof(env->max_label));
104 
105  // Write weight vector bits information
106  io_temp.write_file((char *)&env->num_bits, sizeof(env->num_bits));
107  io_temp.write_file((char *)&env->thread_bits, sizeof(env->thread_bits));
108 
109  // For paired namespaces forming quadratic features
110  int32_t len = env->pairs.get_num_elements();
111  io_temp.write_file((char *)&len, sizeof(len));
112 
113  for (int32_t k = 0; k < env->pairs.get_num_elements(); k++)
114  io_temp.write_file(env->pairs.get_element(k), 2);
115 
116  // ngram and skips information
117  io_temp.write_file((char*)&env->ngram, sizeof(env->ngram));
118  io_temp.write_file((char*)&env->skips, sizeof(env->skips));
119  }
120  else
121  {
122  // Write as human readable form
123  char buff[512];
124  int32_t len;
125 
126  len = sprintf(buff, "Version %s\n", vw_version);
127  io_temp.write_file(buff, len);
128  len = sprintf(buff, "Min label:%f max label:%f\n", env->min_label, env->max_label);
129  io_temp.write_file(buff, len);
130  len = sprintf(buff, "bits:%d thread_bits:%d\n", (int32_t)env->num_bits, (int32_t)env->thread_bits);
131  io_temp.write_file(buff, len);
132 
133  if (env->pairs.get_num_elements() > 0)
134  {
135  len = sprintf(buff, "\n");
136  io_temp.write_file(buff, len);
137  }
138 
139  len = sprintf(buff, "ngram:%d skips:%d\nindex:weight pairs:\n", (int32_t)env->ngram, (int32_t)env->skips);
140  io_temp.write_file(buff, len);
141  }
142 
143  uint32_t length = 1 << env->num_bits;
144  vw_size_t num_threads = env->num_threads();
145  vw_size_t stride = env->stride;
146 
147  // Write individual weights
148  for(uint32_t i = 0; i < length; i++)
149  {
150  float32_t v;
151  v = weight_vectors[i%num_threads][stride*(i/num_threads)];
152  if (v != 0.)
153  {
154  if (!as_text)
155  {
156  io_temp.write_file((char *)&i, sizeof (i));
157  io_temp.write_file((char *)&v, sizeof (v));
158  }
159  else
160  {
161  char buff[512];
162  int32_t len = sprintf(buff, "%d:%f\n", i, v);
163  io_temp.write_file(buff, len);
164  }
165  }
166  }
167 
168  io_temp.close_file();
169 }
170 
171 // TODO: remove this, as we have serialization FW
// NOTE(review): the signature line of this function is missing from this
// listing; the body reads its input path from a variable named `file`.
// Loads a regressor from a binary file: version info, min/max label,
// num_bits, thread_bits, namespace pairs, ngram, skips, and finally a stream
// of (hash, weight) records. The field order mirrors the binary branch of
// dump_regressor(). The environment is updated in place and the weight
// vectors are re-created through init(env).
173 {
174  CIOBuffer source;
175  int32_t fd = source.open_file(file, 'r');
176 
177  if (fd < 0)
178  SG_SERROR("Unable to open file for loading regressor!\n")
179 
180  // Read version info
181  vw_size_t v_length;
182  source.read_file((char*)&v_length, sizeof(v_length));
183  char* t = SG_MALLOC(char, v_length);
184  source.read_file(t,v_length);
// NOTE(review): strcmp needs a NUL terminator inside the v_length bytes just
// read; nothing here adds one, so this relies on the stored string already
// containing it - confirm.
185  if (strcmp(t,env->vw_version) != 0)
186  {
187  SG_FREE(t);
// NOTE(review): presumably SG_SERROR does not return; otherwise t would be
// freed a second time below - confirm.
188  SG_SERROR("Regressor source has an incompatible VW version!\n")
189  }
190  SG_FREE(t);
191 
192  // Read min and max label
193  source.read_file((char*)&env->min_label, sizeof(env->min_label));
194  source.read_file((char*)&env->max_label, sizeof(env->max_label));
195 
196  // Read num_bits, multiple sources are not supported
// NOTE(review): sizeof(vw_size_t) bytes are read here, while dump_regressor()
// writes sizeof(env->num_bits) bytes; the two must match - confirm the type
// of env->num_bits (same question for thread_bits, ngram and skips below).
197  vw_size_t local_num_bits;
198  source.read_file((char *)&local_num_bits, sizeof(local_num_bits));
199 
200  if ((vw_size_t) env->num_bits != local_num_bits)
201  SG_SERROR("Wrong number of bits in regressor source!\n")
202 
203  env->num_bits = local_num_bits;
204 
205  vw_size_t local_thread_bits;
206  source.read_file((char*)&local_thread_bits, sizeof(local_thread_bits));
207 
208  env->thread_bits = local_thread_bits;
209 
210  int32_t len;
211  source.read_file((char *)&len, sizeof(len));
212 
213  // Read paired namespace information
214  DynArray<char*> local_pairs;
215  for (; len > 0; len--)
216  {
217  char pair[3];
218  source.read_file(pair, sizeof(char)*2);
219  pair[2]='\0';
// NOTE(review): this stores the address of the loop-local array `pair`; the
// array goes out of scope at the end of each iteration, so the pointers kept
// in local_pairs (and later env->pairs) look dangling - confirm and copy the
// two characters into owned storage if so.
220  local_pairs.push_back(pair);
221  }
222 
223  env->pairs = local_pairs;
224 
225  // Initialize the weight vector
// NOTE(review): only the outer pointer array is freed; the per-thread buffers
// allocated by a previous init() are not released here. init(env) also runs
// SG_REF(env) again on the environment this object already references.
226  if (weight_vectors)
227  SG_FREE(weight_vectors);
228  init(env);
229 
230  vw_size_t local_ngram;
231  source.read_file((char*)&local_ngram, sizeof(local_ngram));
232  vw_size_t local_skips;
233  source.read_file((char*)&local_skips, sizeof(local_skips));
234 
235  env->ngram = local_ngram;
236  env->skips = local_skips;
237 
238  // Read individual weights
239  vw_size_t stride = env->stride;
// Records are read until a read returns no bytes (or an error).
240  while (true)
241  {
242  uint32_t hash;
243  ssize_t hash_bytes = source.read_file((char *)&hash, sizeof(hash));
244  if (hash_bytes <= 0)
245  break;
246 
247  float32_t w = 0.;
248  ssize_t weight_bytes = source.read_file((char *)&w, sizeof(float32_t));
249  if (weight_bytes <= 0)
250  break;
251 
252  vw_size_t num_threads = env->num_threads();
253 
// The stored weight is added to the current value rather than replacing it.
// NOTE(review): the index here is (hash*stride)/num_threads, whereas
// dump_regressor() reads stride*(i/num_threads); the two agree when
// num_threads is 1 - confirm for other thread counts.
254  weight_vectors[hash % num_threads][(hash*stride)/num_threads]
255  = weight_vectors[hash % num_threads][(hash*stride)/num_threads] + w;
256  }
257  source.close_file();
258 }

SHOGUN Machine Learning Toolbox - Documentation