Commit d329acf4 authored by Tiago Peixoto's avatar Tiago Peixoto

uncertain_blockmodel: implement set_state()

parent bf3c8e36
......@@ -101,6 +101,18 @@ void export_uncertain_state()
no_init);
c.def("remove_edge", &state_t::remove_edge)
.def("add_edge", &state_t::add_edge)
.def("set_state",
+[](state_t& state, GraphInterface& gi,
boost::any aw)
{
typedef eprop_map_t<int32_t>::type emap_t;
auto w = any_cast<emap_t>(aw).get_unchecked();
gt_dispatch<>()
([&](auto& g)
{ set_state(state, g, w); },
all_graph_views())
(gi.get_graph_view());
})
.def("remove_edge_dS", &state_t::remove_edge_dS)
.def("add_edge_dS", &state_t::add_edge_dS)
.def("entropy", &state_t::entropy)
......
......@@ -96,6 +96,40 @@ void get_xedges_prob(State& state, python::object edges, python::object probs,
(es.shape()[1] > 2) ? es[i][2] : 0);
}
template <class State, class Graph, class EProp>
void set_state(State& state, Graph& u, EProp w)
{
std::vector<std::pair<size_t, size_t>> us;
for (auto v : vertices_range(state._u))
{
us.clear();
for (auto e : out_edges_range(v, state._u))
{
auto w = target(e, state._u);
if (w == v)
continue;
us.emplace_back(w, state._eweight[e]);
}
for (auto& uw : us)
{
for (size_t i = 0; i < uw.second; ++i)
state.remove_edge(v, uw.first);
}
auto& e = state.template get_u_edge<false>(v, v);
if (e == state._null_edge)
continue;
size_t x = state._eweight[e];
for (size_t i = 0; i < x; ++i)
state.remove_edge(v, v);
}
for (auto e : edges_range(u))
{
for (size_t i = 0; i < size_t(w[e]); ++i)
state.add_edge(source(e, u), target(e, u));
}
}
} // graph_tool namespace
......
......@@ -109,6 +109,21 @@ class UncertainBaseState(object):
init_q_cache()
def __getstate__(self):
self.u.ep.w = self.eweight
u = self.u.copy()
eweight = u.ep.w
del u.ep["w"]
del self.u.ep["w"]
return dict(g=self.g, nested=self.nbstate is not None,
bstate=(self.nbstate.copy(g=u, state_args=dict(eweight=eweight))
if self.nbstate is not None else
self.bstate.copy(g=u, eweight=eweight)),
self_loops=self.self_loops)
def __setstate__(self, state):
self.__init__(**state)
def get_block_state(self):
"""Return the underlying block state, which can be either
:class:`~graph_tool.inference.blockmodel.BlockState` or
......@@ -150,6 +165,11 @@ class UncertainBaseState(object):
entropy_args = get_uentropy_args(dentropy_args)
return self._state.add_edge_dS(int(u), int(v), entropy_args)
def set_state(self, g, w):
if w.value_type() != "int32_t":
w = w.copy("int32_t")
self._state.set_state(g._Graph__graph, w._get_any())
def _algo_sweep(self, algo, r=.5, **kwargs):
kwargs = kwargs.copy()
beta = kwargs.get("beta", 1.)
......@@ -403,13 +423,9 @@ class UncertainBlockState(UncertainBaseState):
self._state = libinference.make_uncertain_state(self.bstate._state,
self)
def __getstate__(self):
return dict(g=self.g, q=self._q, q_default=self._q_default,
aE=self.aE, nested=self.nbstate is not None,
bstate=(self.nbstate.copy() if self.nbstate is not None else
self.bstate.copy()), self_loops=self.self_loops)
def __setstate__(self, state):
self.__init__(**state)
state = super(UncertainBlockState, self).__getstate__()
return dict(state, q=self._q, q_default=self._q_default,
aE=self.aE)
def copy(self, **kwargs):
"""Return a copy of the state."""
......@@ -482,9 +498,8 @@ class LatentMultigraphBlockState(UncertainBaseState):
self._state = libinference.make_uncertain_state(self.bstate._state,
self)
def __getstate__(self):
return dict(g=self.g, aE=self.aE, nested=self.nbstate is not None,
bstate=(self.nbstate.copy() if self.nbstate is not None else
self.bstate.copy()), self_loops=self.self_loops)
state = super(LatentMultigraphBlockState, self).__getstate__()
return dict(state, aE=self.aE)
def __setstate__(self, state):
self.__init__(**state)
......@@ -584,16 +599,11 @@ class MeasuredBlockState(UncertainBaseState):
self)
def __getstate__(self):
return dict(g=self.g, n=self.n, x=self.x, n_default=self.n_default,
state = super(MeasuredBlockState, self).__getstate__()
return dict(state, n=self.n, x=self.x, n_default=self.n_default,
x_default=self.x_default,
fn_params=dict(alpha=self.alpha, beta=self.beta),
fp_params=dict(mu=self.mu, nu=self.nu), aE=self.aE,
nested=self.nbstate is not None,
bstate=(self.nbstate if self.nbstate is not None
else self.bstate), self_loops=self.self_loops)
def __setstate__(self, state):
self.__init__(**state)
fp_params=dict(mu=self.mu, nu=self.nu), aE=self.aE)
def copy(self, **kwargs):
"""Return a copy of the state."""
......@@ -739,16 +749,11 @@ class MixedMeasuredBlockState(UncertainBaseState):
self.sync_q()
def __getstate__(self):
return dict(g=self.g, n=self.n, x=self.x, n_default=self.n_default,
state = super(MixedMeasuredBlockState, self).__getstate__()
return dict(state, n=self.n, x=self.x, n_default=self.n_default,
x_default=self.x_default,
fn_params=dict(alpha=self.alpha, beta=self.beta),
fp_params=dict(mu=self.mu, nu=self.nu), aE=self.aE,
nested=self.nbstate is not None,
bstate=(self.nbstate if self.nbstate is not None
else self.bstate), self_loops=self.self_loops)
def __setstate__(self, state):
self.__init__(**state)
fp_params=dict(mu=self.mu, nu=self.nu), aE=self.aE)
def copy(self, **kwargs):
"""Return a copy of the state."""
......@@ -869,15 +874,10 @@ class DynamicsBlockStateBase(UncertainBaseState):
self._state.set_params(self.params)
def __getstate__(self):
return dict(g=self.g, s=self.s, t=self.t, x=self.x, aE=self.aE,
nested=self.nbstate is not None,
bstate=(self.nbstate.copy() if self.nbstate is not None else
self.bstate.copy()), self_loops=self.self_loops,
state = super(DynamicsBlockState, self).__getstate__()
return dict(state, s=self.s, t=self.t, x=self.x, aE=self.aE,
**self.params)
def __setstate__(self, state):
self.__init__(**state)
def copy(self, **kwargs):
"""Return a copy of the state."""
return type(self)(**dict(self.__getstate__(), **kwargs))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment