Commit f4d03d03 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

blockmodel: copy branch

parent eefb2c4b
......@@ -72,8 +72,12 @@ typedef mpl::vector1<std::true_type> rmap_tr;
typedef mpl::vector1<std::false_type> rmap_tr;
#endif
#ifndef GRAPH_RANGE
#define GRAPH_RANGE all_graph_views
#endif
#define BLOCK_STATE_params \
((g, &, all_graph_views, 1)) \
((g, &, GRAPH_RANGE, 1)) \
((is_weighted,, mpl::vector1<std::true_type>, 1)) \
((use_hash,, bool_tr, 1)) \
((use_rmap,, rmap_tr, 1)) \
......@@ -1522,6 +1526,30 @@ public:
_b[v] = s;
}
void copy_branch(size_t r, BlockStateVirtualBase& state)
{
if (r >= num_vertices(_bg))
add_block(r - num_vertices(_bg) + 1);
_bclabel[r] = state.get_bclabel()[r];
if (_coupled_state != nullptr)
{
auto& cstate = *state.get_coupled_state();
auto& sbh = cstate.get_b();
auto s = sbh[r];
_coupled_state->copy_branch(s, cstate);
auto& bh = _coupled_state->get_b();
bh[r] = s;
auto& pclabel = cstate.get_pclabel();
auto& hpclabel = _coupled_state->get_pclabel();
hpclabel[r] = pclabel[r];
}
}
// Sample node placement
size_t sample_block(size_t v, double c, double d, rng_t& rng)
{
......@@ -2207,6 +2235,11 @@ public:
_coupled_state = nullptr;
}
BlockStateVirtualBase* get_coupled_state()
{
return _coupled_state;
}
void clear_egroups()
{
_egroups.clear();
......
......@@ -84,7 +84,7 @@ python::object do_gibbs_sweep_parallel(python::object ogibbs_states,
std::vector<std::shared_ptr<gibbs_sweep_base>> sweeps;
size_t N = python::len(ogibbs_states);
for (size_t i = 0; i < N; ++ i)
for (size_t i = 0; i < N; ++i)
{
auto dispatch = [&](auto& block_state)
{
......
......@@ -37,9 +37,11 @@ public:
virtual void remove_partition_node(size_t v, size_t r) = 0;
virtual void set_vertex_weight(size_t v, int w) = 0;
virtual void coupled_resize_vertex(size_t v) = 0;
virtual BlockStateVirtualBase* get_coupled_state() = 0;
virtual double virtual_move(size_t v, size_t r, size_t nr,
const entropy_args_t& eargs) = 0;
virtual void sample_branch(size_t v, size_t u, rng_t& rng) = 0;
virtual void copy_branch(size_t v, BlockStateVirtualBase&) = 0;
virtual size_t sample_block(size_t v, double c, double d, rng_t& rng) = 0;
virtual double get_move_prob(size_t v, size_t r, size_t s, double c, double d,
bool reverse) = 0;
......@@ -67,7 +69,9 @@ public:
const entropy_args_t& ea) = 0;
virtual vprop_map_t<int32_t>::type::unchecked_t& get_b() = 0;
virtual vprop_map_t<int32_t>::type::unchecked_t& get_pclabel() = 0;
virtual vprop_map_t<int32_t>::type::unchecked_t& get_bclabel() = 0;
virtual bool check_edge_counts(bool emat=true) = 0;
virtual void check_node_counts() = 0;
virtual bool allow_move(size_t r, size_t nr) = 0;
virtual void relax_update(bool relax) = 0;
};
......
......@@ -610,6 +610,10 @@ struct Layers
BaseState::sample_branch(v, u, rng);
}
void copy_branch(size_t, BlockStateVirtualBase&)
{
}
double entropy(const entropy_args_t& ea, bool propagate=false)
{
double S = 0, S_dl = 0;
......@@ -795,6 +799,11 @@ struct Layers
state.decouple_state();
}
BlockStateVirtualBase* get_coupled_state()
{
return _lcoupled_state;
}
void couple_state(BlockStateVirtualBase& s,
const entropy_args_t& ea)
{
......@@ -984,6 +993,11 @@ struct Layers
return BaseState::_pclabel;
}
vprop_map_t<int32_t>::type::unchecked_t& get_bclabel()
{
return BaseState::_bclabel;
}
void sync_emat()
{
BaseState::sync_emat();
......@@ -1021,6 +1035,15 @@ struct Layers
return true;
}
void check_node_counts()
{
BaseState::check_node_counts();
for (auto& state : _layers)
state.check_edge_counts();
if (_lcoupled_state != nullptr)
_lcoupled_state->check_node_counts();
}
bool check_layers()
{
for (auto v : vertices_range(_g))
......
......@@ -623,6 +623,10 @@ public:
{
}
void copy_branch(size_t, BlockStateVirtualBase&)
{
}
template <class RNG>
size_t get_lateral_half_edge(size_t v, RNG& rng)
{
......@@ -988,6 +992,11 @@ public:
_coupled_state = nullptr;
}
BlockStateVirtualBase* get_coupled_state()
{
return _coupled_state;
}
void clear_egroups()
{
_egroups.clear();
......@@ -1170,6 +1179,11 @@ public:
return _pclabel;
}
vprop_map_t<int32_t>::type::unchecked_t& get_bclabel()
{
return _bclabel;
}
template <class MCMCState>
void init_mcmc(MCMCState& state)
{
......@@ -1233,6 +1247,12 @@ public:
return true;
}
void check_node_counts()
{
if (_coupled_state != nullptr)
_coupled_state->check_node_counts();
}
void add_partition_node(size_t, size_t) { }
void remove_partition_node(size_t, size_t) { }
void set_vertex_weight(size_t, int) { }
......
......@@ -120,6 +120,9 @@ class NestedBlockState(object):
self._consistency_check()
def _regen_Lrecdx(self, lstate=None):
if not hasattr(self.levels[0], "recdx"):
return
if lstate is None:
levels = self.levels
Lrecdx = self.Lrecdx
......@@ -305,6 +308,9 @@ class NestedBlockState(object):
return S
def _Lrecdx_entropy(self, Lrecdx=None):
if not hasattr(self.levels[0], "recdx"):
return 0
if self.base_type is not LayeredBlockState:
S_D = 0
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment