SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
BeliefPropagation.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Copyright (C) 2013 Shell Hu
9  */
10 
11 #ifndef __BELIEF_PROPAGATION_H__
12 #define __BELIEF_PROPAGATION_H__
13 
14 #include <shogun/lib/config.h>
15 
16 #include <shogun/lib/SGVector.h>
19 
20 #include <vector>
21 #include <set>
22 
23 #ifdef HAVE_STD_UNORDERED_MAP
24  #include <unordered_map>
25 #else
26  #include <tr1/unordered_map>
27 #endif
28 
29 #ifndef DOXYGEN_SHOULD_SKIP_THIS
30 
31 namespace shogun
32 {
33 #define IGNORE_IN_CLASSLIST
34 
/** Kind of a node in the factor graph: a variable or a factor. */
enum ENodeType
{
    /** variable node */
    VAR_NODE = 0,
    /** factor node */
    FAC_NODE = 1
};

/** Direction in which a message travels along an edge. */
enum EEdgeType
{
    /** message sent from a variable to a factor */
    VAR_TO_FAC = 0,
    /** message sent from a factor to a variable */
    FAC_TO_VAR = 1
};

/** Plain record describing one node together with the node it was reached from.
 *
 * No destructor is declared on purpose: the struct owns no resources, so the
 * implicitly generated one is sufficient and keeps the type trivially
 * destructible.
 */
struct GraphNode
{
    /** id of this node */
    int32_t node_id;
    /** VAR_NODE (0) for a variable, FAC_NODE (1) for a factor */
    ENodeType node_type;
    /** id of the node this one was reached from */
    int32_t parent;

    /** Constructor
     *
     * @param id node id
     * @param type whether the node is a variable or a factor
     * @param pa id of the node this one was reached from
     */
    GraphNode(int32_t id, ENodeType type, int32_t pa)
        : node_id(id), node_type(type), parent(pa)
    {
    }
};

/** Plain record describing a directed message between a variable and a factor.
 *
 * Which of child / parent is the variable depends on mtype; use
 * get_var_node() and get_factor_node() instead of reading the fields
 * directly when the role matters.
 */
struct MessageEdge
{
    /** VAR_TO_FAC (0) for variable-to-factor, FAC_TO_VAR (1) for factor-to-variable */
    EEdgeType mtype;
    /** child end of the edge */
    int32_t child;
    /** parent end of the edge */
    int32_t parent;

    /** Constructor
     *
     * @param type direction of the message
     * @param ch child end of the edge
     * @param pa parent end of the edge
     */
    MessageEdge(EEdgeType type, int32_t ch, int32_t pa)
        : mtype(type), child(ch), parent(pa)
    {
    }

    /** @return the variable end of the edge: child for a VAR_TO_FAC
     * message, parent otherwise
     */
    int32_t get_var_node() const
    {
        if (mtype == VAR_TO_FAC)
            return child;

        return parent;
    }

    /** @return the factor end of the edge: parent for a VAR_TO_FAC
     * message, child otherwise
     */
    int32_t get_factor_node() const
    {
        if (mtype == VAR_TO_FAC)
            return parent;

        return child;
    }
};
79 
81 IGNORE_IN_CLASSLIST class CBeliefPropagation : public CMAPInferImpl
82 {
83 public:
84  CBeliefPropagation();
85  CBeliefPropagation(CFactorGraph* fg);
86 
87  virtual ~CBeliefPropagation();
88 
90  virtual const char* get_name() const { return "BeliefPropagation"; }
91 
92  virtual float64_t inference(SGVector<int32_t> assignment);
93 
94 protected:
95  float64_t m_map_energy;
96 };
97 
106 IGNORE_IN_CLASSLIST class CTreeMaxProduct : public CBeliefPropagation
107 {
108 #ifdef HAVE_STD_UNORDERED_MAP
109  typedef std::unordered_map<uint32_t, uint32_t> msg_map_type;
110  typedef std::unordered_map<uint32_t, std::set<uint32_t> > msgset_map_type;
111  typedef std::unordered_multimap<int32_t, int32_t> var_factor_map_type;
112 #else
113  typedef std::tr1::unordered_map<uint32_t, uint32_t> msg_map_type;
114  typedef std::tr1::unordered_map<uint32_t, std::set<uint32_t> > msgset_map_type;
115  typedef std::tr1::unordered_multimap<int32_t, int32_t> var_factor_map_type;
116 #endif
117 
118 public:
119  CTreeMaxProduct();
120  CTreeMaxProduct(CFactorGraph* fg);
121 
122  virtual ~CTreeMaxProduct();
123 
125  virtual const char* get_name() const { return "TreeMaxProduct"; }
126 
127  virtual float64_t inference(SGVector<int32_t> assignment);
128 
129 protected:
130  void bottom_up_pass();
131  void top_down_pass();
132  void get_message_order(std::vector<MessageEdge*>& order, std::vector<bool>& is_root) const;
133 
134 private:
135  void init();
136 
137 private:
138  std::vector<MessageEdge*> m_msg_order;
139  std::vector<bool> m_is_root;
140  std::vector< std::vector<float64_t> > m_fw_msgs;
141  std::vector< std::vector<float64_t> > m_bw_msgs;
142  std::vector<int32_t> m_states;
143 
144  msg_map_type m_msg_map_var;
145  msg_map_type m_msg_map_fac;
146  msgset_map_type m_msgset_map_var;
147 };
148 
149 }
150 
151 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
152 
153 #endif

SHOGUN Machine Learning Toolbox - Documentation