Commit 2f049ff6 authored by Tiago Peixoto's avatar Tiago Peixoto

Fix computation and sampling from marginal multigraph

parent f2d4b756
......@@ -72,9 +72,15 @@ double marginal_count_entropy(GraphInterface& gi, boost::any aexc, boost::any ae
double marginal_multigraph_sample(GraphInterface& gi, boost::any axs,
boost::any axc, boost::any ax, rng_t& rng);
double marginal_multigraph_lprob(GraphInterface& gi, boost::any axs,
boost::any axc, boost::any ax);
double marginal_graph_sample(GraphInterface& gi, boost::any ap,
boost::any ax, rng_t& rng);
double marginal_graph_lprob(GraphInterface& gi, boost::any ap,
boost::any ax);
void export_uncertain_state()
{
using namespace boost::python;
......@@ -142,5 +148,7 @@ void export_uncertain_state()
def("collect_marginal_count", &collect_marginal_count_dispatch);
def("marginal_count_entropy", &marginal_count_entropy);
def("marginal_multigraph_sample", &marginal_multigraph_sample);
def("marginal_multigraph_lprob", &marginal_multigraph_lprob);
def("marginal_graph_sample", &marginal_graph_sample);
def("marginal_graph_lprob", &marginal_graph_lprob);
}
......@@ -17,6 +17,7 @@
#include "graph_tool.hh"
#include "graph_blockmodel_uncertain_marginal.hh"
#include "parallel_rng.hh"
using namespace boost;
using namespace graph_tool;
......@@ -116,8 +117,32 @@ double marginal_count_entropy(GraphInterface& gi, boost::any aexc, boost::any ae
return S_tot;
}
double marginal_multigraph_sample(GraphInterface& gi, boost::any axs, boost::any axc,
boost::any ax, rng_t& rng)
void marginal_multigraph_sample(GraphInterface& gi, boost::any axs, boost::any axc,
boost::any ax, rng_t& rng_)
{
gt_dispatch<>()
([&](auto& g, auto& xs, auto& xc, auto& x)
{
parallel_rng<rng_t>::init(rng_);
parallel_edge_loop
(g,
[&](auto& e)
{
typedef std::remove_reference_t<decltype(xs[e][0])> val_t;
std::vector<double> probs(xc[e].begin(), xc[e].end());
Sampler<val_t> sample(xs[e], probs);
auto& rng = parallel_rng<rng_t>::get(rng_);
x[e] = sample.sample(rng);
});
},
all_graph_views(), edge_scalar_vector_properties(),
edge_scalar_vector_properties(), writable_edge_scalar_properties())
(gi.get_graph_view(), axs, axc, ax);
}
double marginal_multigraph_lprob(GraphInterface& gi, boost::any axs, boost::any axc,
boost::any ax)
{
double L = 0;
gt_dispatch<>()
......@@ -125,31 +150,52 @@ double marginal_multigraph_sample(GraphInterface& gi, boost::any axs, boost::any
{
for (auto e : edges_range(g))
{
typedef std::remove_reference_t<decltype(xs[e][0])> val_t;
std::vector<double> probs(xc[e].begin(), xc[e].end());
Sampler<val_t> sample(xs[e], probs);
x[e] = sample.sample(rng);
size_t Z = 0;
size_t p = 0;
for (size_t i = 0; i < xs[e].size(); ++i)
{
auto m = xs[e][i];
if (m == x[e])
size_t m = xs[e][i];
if (m == size_t(x[e]))
p = xc[e][i];
Z += xc[e][i];
}
if (p == 0)
{
L = -numeric_limits<double>::infinity();
break;
}
L += std::log(p) - std::log(Z);
}
},
all_graph_views(), edge_scalar_vector_properties(),
edge_scalar_vector_properties(), writable_edge_scalar_properties())
edge_scalar_vector_properties(), edge_scalar_properties())
(gi.get_graph_view(), axs, axc, ax);
return L;
}
double marginal_graph_sample(GraphInterface& gi, boost::any ap,
boost::any ax, rng_t& rng)
void marginal_graph_sample(GraphInterface& gi, boost::any ap,
boost::any ax, rng_t& rng_)
{
gt_dispatch<>()
([&](auto& g, auto& p, auto& x)
{
parallel_rng<rng_t>::init(rng_);
parallel_edge_loop
(g,
[&](auto& e)
{
std::bernoulli_distribution sample(p[e]);
auto& rng = parallel_rng<rng_t>::get(rng_);
x[e] = sample(rng);
});
},
all_graph_views(), edge_scalar_properties(),
writable_edge_scalar_properties())
(gi.get_graph_view(), ap, ax);
}
double marginal_graph_lprob(GraphInterface& gi, boost::any ap,
boost::any ax)
{
double L = 0;
gt_dispatch<>()
......@@ -157,8 +203,6 @@ double marginal_graph_sample(GraphInterface& gi, boost::any ap,
{
for (auto e : edges_range(g))
{
std::bernoulli_distribution sample(p[e]);
x[e] = sample(rng);
if (x[e] == 1)
L += std::log(p[e]);
else
......@@ -166,7 +210,7 @@ double marginal_graph_sample(GraphInterface& gi, boost::any ap,
}
},
all_graph_views(), edge_scalar_properties(),
writable_edge_scalar_properties())
edge_scalar_properties())
(gi.get_graph_view(), ap, ax);
return L;
}
......@@ -189,6 +189,8 @@ __all__ = ["minimize_blockmodel_dl",
"marginal_multigraph_entropy",
"marginal_multigraph_sample",
"marginal_graph_sample",
"marginal_multigraph_lprob",
"marginal_graph_lprob",
"PartitionHist",
"BlockPairHist",
"half_edge_graph",
......
......@@ -1519,20 +1519,34 @@ def marginal_multigraph_entropy(g, ecount):
_prop("e", g, eh))
return eh
def marginal_multigraph_sample(g, ew, ecount):
def marginal_multigraph_sample(g, ews, ecount):
w = g.new_ep("int")
L = libinference.marginal_multigraph_sample(g._Graph__graph,
_prop("e", g, ew),
libinference.marginal_multigraph_sample(g._Graph__graph,
_prop("e", g, ews),
_prop("e", g, ecount),
_prop("e", g, w),
_get_rng())
return w, L
return w
def marginal_multigraph_lprob(g, ews, ecount, ew):
L = libinference.marginal_multigraph_lprob(g._Graph__graph,
_prop("e", g, ews),
_prop("e", g, ecount),
_prop("e", g, ew))
return L
def marginal_graph_sample(g, ep):
w = g.new_ep("int")
L = libinference.marginal_graph_sample(g._Graph__graph,
_prop("e", g, ep),
_prop("e", g, w),
_get_rng())
return w, L
libinference.marginal_graph_sample(g._Graph__graph,
_prop("e", g, ep),
_prop("e", g, w),
_get_rng())
return w
def marginal_graph_lprob(g, ep, w):
L = libinference.marginal_graph_lprob(g._Graph__graph,
_prop("e", g, ep),
_prop("e", g, w))
return L
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment