Commit 2f049ff6 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

Fix computation and sampling from marginal multigraph

parent f2d4b756
...@@ -72,9 +72,15 @@ double marginal_count_entropy(GraphInterface& gi, boost::any aexc, boost::any ae ...@@ -72,9 +72,15 @@ double marginal_count_entropy(GraphInterface& gi, boost::any aexc, boost::any ae
double marginal_multigraph_sample(GraphInterface& gi, boost::any axs, double marginal_multigraph_sample(GraphInterface& gi, boost::any axs,
boost::any axc, boost::any ax, rng_t& rng); boost::any axc, boost::any ax, rng_t& rng);
double marginal_multigraph_lprob(GraphInterface& gi, boost::any axs,
boost::any axc, boost::any ax);
double marginal_graph_sample(GraphInterface& gi, boost::any ap, double marginal_graph_sample(GraphInterface& gi, boost::any ap,
boost::any ax, rng_t& rng); boost::any ax, rng_t& rng);
double marginal_graph_lprob(GraphInterface& gi, boost::any ap,
boost::any ax);
void export_uncertain_state() void export_uncertain_state()
{ {
using namespace boost::python; using namespace boost::python;
...@@ -142,5 +148,7 @@ void export_uncertain_state() ...@@ -142,5 +148,7 @@ void export_uncertain_state()
def("collect_marginal_count", &collect_marginal_count_dispatch); def("collect_marginal_count", &collect_marginal_count_dispatch);
def("marginal_count_entropy", &marginal_count_entropy); def("marginal_count_entropy", &marginal_count_entropy);
def("marginal_multigraph_sample", &marginal_multigraph_sample); def("marginal_multigraph_sample", &marginal_multigraph_sample);
def("marginal_multigraph_lprob", &marginal_multigraph_lprob);
def("marginal_graph_sample", &marginal_graph_sample); def("marginal_graph_sample", &marginal_graph_sample);
def("marginal_graph_lprob", &marginal_graph_lprob);
} }
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "graph_tool.hh" #include "graph_tool.hh"
#include "graph_blockmodel_uncertain_marginal.hh" #include "graph_blockmodel_uncertain_marginal.hh"
#include "parallel_rng.hh"
using namespace boost; using namespace boost;
using namespace graph_tool; using namespace graph_tool;
...@@ -116,8 +117,32 @@ double marginal_count_entropy(GraphInterface& gi, boost::any aexc, boost::any ae ...@@ -116,8 +117,32 @@ double marginal_count_entropy(GraphInterface& gi, boost::any aexc, boost::any ae
return S_tot; return S_tot;
} }
double marginal_multigraph_sample(GraphInterface& gi, boost::any axs, boost::any axc, void marginal_multigraph_sample(GraphInterface& gi, boost::any axs, boost::any axc,
boost::any ax, rng_t& rng) boost::any ax, rng_t& rng_)
{
gt_dispatch<>()
([&](auto& g, auto& xs, auto& xc, auto& x)
{
parallel_rng<rng_t>::init(rng_);
parallel_edge_loop
(g,
[&](auto& e)
{
typedef std::remove_reference_t<decltype(xs[e][0])> val_t;
std::vector<double> probs(xc[e].begin(), xc[e].end());
Sampler<val_t> sample(xs[e], probs);
auto& rng = parallel_rng<rng_t>::get(rng_);
x[e] = sample.sample(rng);
});
},
all_graph_views(), edge_scalar_vector_properties(),
edge_scalar_vector_properties(), writable_edge_scalar_properties())
(gi.get_graph_view(), axs, axc, ax);
}
double marginal_multigraph_lprob(GraphInterface& gi, boost::any axs, boost::any axc,
boost::any ax)
{ {
double L = 0; double L = 0;
gt_dispatch<>() gt_dispatch<>()
...@@ -125,31 +150,52 @@ double marginal_multigraph_sample(GraphInterface& gi, boost::any axs, boost::any ...@@ -125,31 +150,52 @@ double marginal_multigraph_sample(GraphInterface& gi, boost::any axs, boost::any
{ {
for (auto e : edges_range(g)) for (auto e : edges_range(g))
{ {
typedef std::remove_reference_t<decltype(xs[e][0])> val_t;
std::vector<double> probs(xc[e].begin(), xc[e].end());
Sampler<val_t> sample(xs[e], probs);
x[e] = sample.sample(rng);
size_t Z = 0; size_t Z = 0;
size_t p = 0; size_t p = 0;
for (size_t i = 0; i < xs[e].size(); ++i) for (size_t i = 0; i < xs[e].size(); ++i)
{ {
auto m = xs[e][i]; size_t m = xs[e][i];
if (m == x[e]) if (m == size_t(x[e]))
p = xc[e][i]; p = xc[e][i];
Z += xc[e][i]; Z += xc[e][i];
} }
if (p == 0)
{
L = -numeric_limits<double>::infinity();
break;
}
L += std::log(p) - std::log(Z); L += std::log(p) - std::log(Z);
} }
}, },
all_graph_views(), edge_scalar_vector_properties(), all_graph_views(), edge_scalar_vector_properties(),
edge_scalar_vector_properties(), writable_edge_scalar_properties()) edge_scalar_vector_properties(), edge_scalar_properties())
(gi.get_graph_view(), axs, axc, ax); (gi.get_graph_view(), axs, axc, ax);
return L; return L;
} }
double marginal_graph_sample(GraphInterface& gi, boost::any ap, void marginal_graph_sample(GraphInterface& gi, boost::any ap,
boost::any ax, rng_t& rng) boost::any ax, rng_t& rng_)
{
gt_dispatch<>()
([&](auto& g, auto& p, auto& x)
{
parallel_rng<rng_t>::init(rng_);
parallel_edge_loop
(g,
[&](auto& e)
{
std::bernoulli_distribution sample(p[e]);
auto& rng = parallel_rng<rng_t>::get(rng_);
x[e] = sample(rng);
});
},
all_graph_views(), edge_scalar_properties(),
writable_edge_scalar_properties())
(gi.get_graph_view(), ap, ax);
}
double marginal_graph_lprob(GraphInterface& gi, boost::any ap,
boost::any ax)
{ {
double L = 0; double L = 0;
gt_dispatch<>() gt_dispatch<>()
...@@ -157,8 +203,6 @@ double marginal_graph_sample(GraphInterface& gi, boost::any ap, ...@@ -157,8 +203,6 @@ double marginal_graph_sample(GraphInterface& gi, boost::any ap,
{ {
for (auto e : edges_range(g)) for (auto e : edges_range(g))
{ {
std::bernoulli_distribution sample(p[e]);
x[e] = sample(rng);
if (x[e] == 1) if (x[e] == 1)
L += std::log(p[e]); L += std::log(p[e]);
else else
...@@ -166,7 +210,7 @@ double marginal_graph_sample(GraphInterface& gi, boost::any ap, ...@@ -166,7 +210,7 @@ double marginal_graph_sample(GraphInterface& gi, boost::any ap,
} }
}, },
all_graph_views(), edge_scalar_properties(), all_graph_views(), edge_scalar_properties(),
writable_edge_scalar_properties()) edge_scalar_properties())
(gi.get_graph_view(), ap, ax); (gi.get_graph_view(), ap, ax);
return L; return L;
} }
...@@ -189,6 +189,8 @@ __all__ = ["minimize_blockmodel_dl", ...@@ -189,6 +189,8 @@ __all__ = ["minimize_blockmodel_dl",
"marginal_multigraph_entropy", "marginal_multigraph_entropy",
"marginal_multigraph_sample", "marginal_multigraph_sample",
"marginal_graph_sample", "marginal_graph_sample",
"marginal_multigraph_lprob",
"marginal_graph_lprob",
"PartitionHist", "PartitionHist",
"BlockPairHist", "BlockPairHist",
"half_edge_graph", "half_edge_graph",
......
...@@ -1519,20 +1519,34 @@ def marginal_multigraph_entropy(g, ecount): ...@@ -1519,20 +1519,34 @@ def marginal_multigraph_entropy(g, ecount):
_prop("e", g, eh)) _prop("e", g, eh))
return eh return eh
def marginal_multigraph_sample(g, ew, ecount): def marginal_multigraph_sample(g, ews, ecount):
w = g.new_ep("int") w = g.new_ep("int")
L = libinference.marginal_multigraph_sample(g._Graph__graph, libinference.marginal_multigraph_sample(g._Graph__graph,
_prop("e", g, ew), _prop("e", g, ews),
_prop("e", g, ecount), _prop("e", g, ecount),
_prop("e", g, w), _prop("e", g, w),
_get_rng()) _get_rng())
return w, L return w
def marginal_multigraph_lprob(g, ews, ecount, ew):
L = libinference.marginal_multigraph_lprob(g._Graph__graph,
_prop("e", g, ews),
_prop("e", g, ecount),
_prop("e", g, ew))
return L
def marginal_graph_sample(g, ep): def marginal_graph_sample(g, ep):
w = g.new_ep("int") w = g.new_ep("int")
L = libinference.marginal_graph_sample(g._Graph__graph, libinference.marginal_graph_sample(g._Graph__graph,
_prop("e", g, ep), _prop("e", g, ep),
_prop("e", g, w), _prop("e", g, w),
_get_rng()) _get_rng())
return w, L return w
def marginal_graph_lprob(g, ep, w):
L = libinference.marginal_graph_lprob(g._Graph__graph,
_prop("e", g, ep),
_prop("e", g, w))
return L
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment