Commit c08bbf86 authored by Tiago Peixoto's avatar Tiago Peixoto

inference.blockmodel: Move UncertainBaseState.get_edge_prob() to C++

This also implements UncertainBaseState.get_edges_prob().
parent 60e235e0
......@@ -23,6 +23,7 @@
#include "../blockmodel/graph_blockmodel.hh"
#define BASE_STATE_params BLOCK_STATE_params
#include "graph_blockmodel_uncertain.hh"
#include "graph_blockmodel_measured.hh"
#include "../support/graph_state.hh"
......@@ -85,7 +86,22 @@ void export_measured_state()
.def("get_N", &state_t::get_N)
.def("get_X", &state_t::get_X)
.def("get_T", &state_t::get_T)
.def("get_M", &state_t::get_M);
.def("get_M", &state_t::get_M)
.def("get_edge_prob",
+[](state_t& state, size_t u, size_t v,
entropy_args_t ea, double epsilon)
{
return get_edge_prob(state, u, v, ea,
epsilon);
})
.def("get_edges_prob",
+[](state_t& state, python::object edges,
python::object probs, entropy_args_t ea,
double epsilon)
{
get_edges_prob(state, edges, probs, ea,
epsilon);
});
});
});
......
......@@ -82,7 +82,23 @@ void export_uncertain_state()
.def("add_edge_dS", &state_t::add_edge_dS)
.def("entropy", &state_t::entropy)
.def("set_q_default", &state_t::set_q_default)
.def("set_S_const", &state_t::set_S_const);
.def("set_S_const", &state_t::set_S_const)
.def("get_edge_prob",
+[](state_t& state, size_t u, size_t v,
entropy_args_t ea, double epsilon)
{
return get_edge_prob(state, u, v, ea,
epsilon);
})
.def("get_edges_prob",
+[](state_t& state, python::object edges,
python::object probs, entropy_args_t ea,
double epsilon)
{
get_edges_prob(state, edges, probs, ea,
epsilon);
});
});
});
......
......@@ -222,6 +222,53 @@ struct Uncertain
};
};
template <class State>
double get_edge_prob(State& state, size_t u, size_t v, entropy_args_t ea,
double epsilon)
{
auto e = state.get_u_edge(u, v);
size_t ew = 0;
if (e != state._null_edge)
ew = state._eweight[e];
for (size_t i = 0; i < ew; ++i)
state.remove_edge(u, v);
double S = 0, Z = 1;
double delta = 1. + epsilon;
size_t ne = 0;
while (delta > epsilon || ne < 2)
{
double dS = state.add_edge_dS(u, v, ea);
state.add_edge(u, v);
S += dS;
ne++;
double dZ = exp(-S);
delta = dZ/Z;
Z += dZ;
}
double L = log1p(-1./Z);
for (int i = 0; i < int(ne - ew); ++i)
state.remove_edge(u, v);
for (int i = 0; i < int(ew - ne); ++i)
state.add_edge(u, v);
return L;
}
template <class State>
void get_edges_prob(State& state, python::object edges, python::object probs,
entropy_args_t ea, double epsilon)
{
multi_array_ref<uint64_t,2> es = get_array<uint64_t,2>(edges);
multi_array_ref<double,1> eprobs = get_array<double,1>(probs);
for (size_t i = 0; i < eprobs.shape()[0]; ++i)
eprobs[i] = get_edge_prob(state, es[i][0], es[i][1], ea, epsilon);
}
} // graph_tool namespace
#endif //GRAPH_BLOCKMODEL_UNCERTAIN_HH
......@@ -161,37 +161,17 @@ class UncertainBaseState(object):
r"""Return conditional posterior probability of edge :math:`(u,v)`."""
entropy_args = dict(self.bstate._entropy_args, **entropy_args)
ea = get_entropy_args(entropy_args)
return self._state.get_edge_prob(u, v, ea, epsilon)
e = self.u.edge(u, v)
if e is None:
ew = 0
else:
ew = self.eweight[e]
for i in range(ew):
self._state.remove_edge(int(u), int(v))
L = 0
delta = 1 + epsilon
ne = 0
S = 0
M = 0
while delta > epsilon:
dS = self._state.add_edge_dS(int(u), int(v), ea)
self._state.add_edge(int(u), int(v))
S += dS
ne += 1
M += log1p(exp(-S-M))
delta = exp(-S-M)
L = log1p(-exp(-M))
for i in range(ne - ew):
self._state.remove_edge(int(u), int(v))
for i in range(ew - ne):
self._state.add_edge(int(u), int(v))
return L
def get_edges_prob(self, elist, entropy_args={}, epsilon=1e-8):
r"""Return conditional posterior probability of an edge list, with
shape :math:`(E,2)`."""
entropy_args = dict(self.bstate._entropy_args, **entropy_args)
ea = get_entropy_args(entropy_args)
elist = numpy.asarray(elist, dtype="uint64")
probs = numpy.zeros(elist.shape[0])
self._state.get_edges_prob(elist, probs, ea, epsilon)
return probs
class UncertainBlockState(UncertainBaseState):
def __init__(self, g, q, q_default=0., phi=numpy.nan, nested=True, state_args={},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment