Commit 9d71c105 authored by Tiago Peixoto's avatar Tiago Peixoto

inference: Fix multiflip MCMC with NestedBlockState

parent 3d854565
...@@ -287,8 +287,15 @@ public: ...@@ -287,8 +287,15 @@ public:
} }
} }
bool allow_move(size_t r, size_t nr, bool allow_empty = true) bool allow_move(size_t v, size_t r, size_t nr, bool allow_empty = true)
{ {
if (_coupled_state != nullptr && is_last(v))
{
auto& bh = _coupled_state->get_b();
if (bh[r] != bh[nr])
return false;
}
if (allow_empty) if (allow_empty)
return ((_bclabel[r] == _bclabel[nr]) || (_wr[nr] == 0)); return ((_bclabel[r] == _bclabel[nr]) || (_wr[nr] == 0));
else else
...@@ -301,7 +308,7 @@ public: ...@@ -301,7 +308,7 @@ public:
if (r == nr) if (r == nr)
return; return;
if (!allow_move(r, nr)) if (!allow_move(v, r, nr))
throw ValueException("cannot move vertex across clabel barriers"); throw ValueException("cannot move vertex across clabel barriers");
get_move_entries(v, r, nr, _m_entries, std::forward<EFilt>(efilt)); get_move_entries(v, r, nr, _m_entries, std::forward<EFilt>(efilt));
...@@ -1405,7 +1412,7 @@ public: ...@@ -1405,7 +1412,7 @@ public:
{ {
assert(size_t(_b[v]) == r || r == null_group); assert(size_t(_b[v]) == r || r == null_group);
if (r != null_group && nr != null_group && !allow_move(r, nr)) if (r != null_group && nr != null_group && !allow_move(v, r, nr))
return std::numeric_limits<double>::infinity(); return std::numeric_limits<double>::infinity();
get_move_entries(v, r, nr, m_entries, [](auto) { return false; }); get_move_entries(v, r, nr, m_entries, [](auto) { return false; });
...@@ -1760,14 +1767,7 @@ public: ...@@ -1760,14 +1767,7 @@ public:
if (_coupled_state != nullptr) if (_coupled_state != nullptr)
{ {
auto& hb = _coupled_state->get_b(); auto& hb = _coupled_state->get_b();
size_t t; hb[s] = hb[_b[v]];
do
{
t = _coupled_state->sample_block(_b[v], c, d, rng);
}
while (!_coupled_state->allow_move(hb[_b[v]], t,
_allow_empty));
hb[s] = t;
} }
return s; return s;
} }
......
...@@ -48,7 +48,7 @@ struct entropy_args_t ...@@ -48,7 +48,7 @@ struct entropy_args_t
bool edges_dl; bool edges_dl;
bool recs_dl; bool recs_dl;
double beta_dl; double beta_dl;
double Bfield; bool Bfield;
}; };
// Sparse entropy terms // Sparse entropy terms
......
...@@ -102,7 +102,7 @@ struct Gibbs ...@@ -102,7 +102,7 @@ struct Gibbs
nr = _state._empty_blocks.back(); nr = _state._empty_blocks.back();
} }
size_t r = _state._b[v]; size_t r = _state._b[v];
if (!_state.allow_move(r, nr)) if (!_state.allow_move(v, r, nr))
return numeric_limits<double>::infinity(); return numeric_limits<double>::infinity();
return _state.virtual_move(v, r, nr, _entropy_args, _m_entries); return _state.virtual_move(v, r, nr, _entropy_args, _m_entries);
} }
......
...@@ -106,7 +106,7 @@ struct MCMC ...@@ -106,7 +106,7 @@ struct MCMC
size_t s = _state.sample_block(v, _c, _d, rng); size_t s = _state.sample_block(v, _c, _d, rng);
if (!_state.allow_move(r, s)) if (!_state.allow_move(v, r, s))
return null_group; return null_group;
return s; return s;
......
...@@ -113,7 +113,7 @@ struct Merge ...@@ -113,7 +113,7 @@ struct Merge
s = uniform_sample(_available, rng); s = uniform_sample(_available, rng);
} }
if (s == r || !_state.allow_move(r, s, false)) if (s == r || !_state.allow_move(v, r, s, false))
return _null_move; return _null_move;
return s; return s;
......
...@@ -89,17 +89,17 @@ struct MCMC ...@@ -89,17 +89,17 @@ struct MCMC
_entropy_args.edges_dl)); _entropy_args.edges_dl));
for (auto v : vertices_range(_state._g)) for (auto v : vertices_range(_state._g))
{ {
if (_state._vweight[v] > 0) if (_state._vweight[v] == 0)
{ continue;
add_element(_groups[_state._b[v]], _vpos, v); add_element(_groups[_state._b[v]], _vpos, v);
_N++; _N++;
}
} }
for (auto r : vertices_range(_state._bg)) for (auto r : vertices_range(_state._bg))
{ {
if (_state._wr[r] > 0) if (_state._wr[r] == 0)
add_element(_vlist, _rpos, r); continue;
add_element(_vlist, _rpos, r);
} }
} }
...@@ -129,7 +129,6 @@ struct MCMC ...@@ -129,7 +129,6 @@ struct MCMC
size_t _t = null_group; size_t _t = null_group;
size_t _u = null_group; size_t _u = null_group;
size_t _v = null_group; size_t _v = null_group;
std::vector<size_t> _mschanged;
size_t node_state(size_t r) size_t node_state(size_t r)
{ {
...@@ -155,13 +154,16 @@ struct MCMC ...@@ -155,13 +154,16 @@ struct MCMC
_groups.resize(t + 1); _groups.resize(t + 1);
_rpos.resize(t + 1); _rpos.resize(t + 1);
} }
assert(_state._wr[t] == 0); assert(_state._wr[t] == 0);
return t; return t;
} }
void move_vertex(size_t v, size_t r) void move_vertex(size_t v, size_t r)
{ {
auto s = _state._b[v]; size_t s = _state._b[v];
if (s == r)
return;
remove_element(_groups[s], _vpos, v); remove_element(_groups[s], _vpos, v);
_state.move_vertex(v, r); _state.move_vertex(v, r);
add_element(_groups[r], _vpos, v); add_element(_groups[r], _vpos, v);
...@@ -197,7 +199,11 @@ struct MCMC ...@@ -197,7 +199,11 @@ struct MCMC
if (rt[1] == null_group) if (rt[1] == null_group)
{ {
rt[1] = sample_new_group(v, rng); if (forward)
rt[1] = sample_new_group(v, rng);
else
rt[1] = (_state.virtual_remove_size(v) == 0) ?
r : sample_new_group(v, rng);
dS += _state.virtual_move(v, _state._b[v], rt[1], dS += _state.virtual_move(v, _state._b[v], rt[1],
_entropy_args); _entropy_args);
if (forward) if (forward)
...@@ -260,7 +266,7 @@ struct MCMC ...@@ -260,7 +266,7 @@ struct MCMC
else else
ddS = std::numeric_limits<double>::infinity(); ddS = std::numeric_limits<double>::infinity();
if (!std::isinf(_beta)) if (!std::isinf(_beta) && !std::isinf(ddS))
{ {
double Z = log_sum(0., -ddS * _beta); double Z = log_sum(0., -ddS * _beta);
p[0] = -ddS * _beta - Z; p[0] = -ddS * _beta - Z;
...@@ -303,7 +309,8 @@ struct MCMC ...@@ -303,7 +309,8 @@ struct MCMC
template <class RNG> template <class RNG>
double split_prob(size_t r, size_t s, RNG& rng) double split_prob(size_t r, size_t s, RNG& rng)
{ {
size_t t = sample_new_group(_vs.front(), rng); size_t t = (_state._wr[r] == 0) ? r : s;
for (auto v : _vs) for (auto v : _vs)
{ {
_btemp[v] = _state._b[v]; _btemp[v] = _state._b[v];
...@@ -366,13 +373,16 @@ struct MCMC ...@@ -366,13 +373,16 @@ struct MCMC
else else
ddS = std::numeric_limits<double>::infinity(); ddS = std::numeric_limits<double>::infinity();
double Z = log_sum(0., -ddS * _beta); if (!std::isinf(ddS))
ddS *= _beta;
double Z = log_sum(0., -ddS);
double p; double p;
if ((size_t(_bprev[v]) == r) == (nbv == r_)) if ((size_t(_bprev[v]) == r) == (nbv == r_))
{ {
_state.move_vertex(v, nbv); _state.move_vertex(v, nbv);
p = -ddS * _beta - Z; p = -ddS - Z;
} }
else else
{ {
...@@ -401,6 +411,17 @@ struct MCMC ...@@ -401,6 +411,17 @@ struct MCMC
return lp; return lp;
} }
bool allow_merge(size_t r, size_t s)
{
if (_state._coupled_state != nullptr)
{
auto& bh = _state._coupled_state->get_b();
if (bh[r] != bh[s])
return false;
}
return _state._bclabel[r] == _state._bclabel[s];
}
template <class RNG, bool symmetric=true> template <class RNG, bool symmetric=true>
std::tuple<size_t, double> merge(size_t r, size_t s, RNG& rng) std::tuple<size_t, double> merge(size_t r, size_t s, RNG& rng)
...@@ -481,7 +502,7 @@ struct MCMC ...@@ -481,7 +502,7 @@ struct MCMC
_groups.resize(_s + 1); _groups.resize(_s + 1);
_rpos.resize(_s + 1); _rpos.resize(_s + 1);
} }
if (r == _s || !_state.allow_move(r, _s)) if (r == _s || !_state.allow_move(v, r, _s))
return _null_move; return _null_move;
_dS = _state.virtual_move(v, r, _s, _entropy_args); _dS = _state.virtual_move(v, r, _s, _entropy_args);
if (!std::isinf(_beta)) if (!std::isinf(_beta))
...@@ -527,7 +548,7 @@ struct MCMC ...@@ -527,7 +548,7 @@ struct MCMC
size_t v = uniform_sample(_groups[r], rng); size_t v = uniform_sample(_groups[r], rng);
_s = _state.sample_block(v, _c, 0, rng); _s = _state.sample_block(v, _c, 0, rng);
} }
if (!_state.allow_move(r, _s)) if (!allow_merge(r, _s))
{ {
_nproposals += _groups[r].size() + _groups[_s].size(); _nproposals += _groups[r].size() + _groups[_s].size();
return _null_move; return _null_move;
...@@ -551,6 +572,7 @@ struct MCMC ...@@ -551,6 +572,7 @@ struct MCMC
_a = 0; _a = 0;
_nproposals += _vs.size(); _nproposals += _vs.size();
return move; return move;
} }
......
...@@ -71,7 +71,7 @@ public: ...@@ -71,7 +71,7 @@ public:
entropy_args_t& ea) = 0; entropy_args_t& ea) = 0;
virtual vprop_map_t<int32_t>::type::unchecked_t& get_b() = 0; virtual vprop_map_t<int32_t>::type::unchecked_t& get_b() = 0;
virtual bool check_edge_counts(bool emat=true) = 0; virtual bool check_edge_counts(bool emat=true) = 0;
virtual bool allow_move(size_t r, size_t nr, bool allow_empty = true) = 0; virtual bool allow_move(size_t v, size_t r, size_t nr, bool allow_empty = true) = 0;
}; };
} // graph_tool namespace } // graph_tool namespace
......
...@@ -487,9 +487,9 @@ struct Layers ...@@ -487,9 +487,9 @@ struct Layers
set_partition(b.get_unchecked()); set_partition(b.get_unchecked());
} }
bool allow_move(size_t r, size_t nr, bool allow_empty = true) bool allow_move(size_t v, size_t r, size_t nr, bool allow_empty = true)
{ {
return BaseState::allow_move(r, nr, allow_empty); return BaseState::allow_move(v, r, nr, allow_empty);
} }
template <class MEntries> template <class MEntries>
...@@ -707,7 +707,7 @@ struct Layers ...@@ -707,7 +707,7 @@ struct Layers
entropy_args_t mea = {false, false, false, false, true, entropy_args_t mea = {false, false, false, false, true,
false, false, false, false, false, false,
ea.degree_dl_kind, false, ea.recs_dl, ea.degree_dl_kind, false, ea.recs_dl,
ea.beta_dl}; ea.beta_dl, false};
for (auto& state : _layers) for (auto& state : _layers)
S += state.entropy(mea); S += state.entropy(mea);
} }
......
...@@ -74,7 +74,7 @@ auto mcmc_sweep(MCMCState state, RNG& rng) ...@@ -74,7 +74,7 @@ auto mcmc_sweep(MCMCState state, RNG& rng)
for (size_t vi = 0; vi < vlist.size(); ++vi) for (size_t vi = 0; vi < vlist.size(); ++vi)
{ {
auto&& v = (state.is_sequential()) ? auto v = (state.is_sequential()) ?
vlist[vi] : uniform_sample(vlist, rng); vlist[vi] : uniform_sample(vlist, rng);
if (state.skip_node(v)) if (state.skip_node(v))
......
...@@ -206,14 +206,14 @@ public: ...@@ -206,14 +206,14 @@ public:
modify_vertex<true>(v, r); modify_vertex<true>(v, r);
} }
bool allow_move(size_t r, size_t nr) bool allow_move(size_t, size_t r, size_t nr)
{ {
return (_bclabel[r] == _bclabel[nr]); return (_bclabel[r] == _bclabel[nr]);
} }
bool allow_move(size_t r, size_t nr, bool) bool allow_move(size_t v, size_t r, size_t nr, bool)
{ {
return allow_move(r, nr); return allow_move(v, r, nr);
} }
// move a vertex from its current block to block nr // move a vertex from its current block to block nr
...@@ -224,7 +224,7 @@ public: ...@@ -224,7 +224,7 @@ public:
if (r == nr) if (r == nr)
return; return;
if (!allow_move(r, nr)) if (!allow_move(v, r, nr))
throw ValueException("cannot move vertex across clabel barriers"); throw ValueException("cannot move vertex across clabel barriers");
bool r_vacate = (_overlap_stats.virtual_remove_size(v, r) == 0); bool r_vacate = (_overlap_stats.virtual_remove_size(v, r) == 0);
...@@ -393,7 +393,7 @@ public: ...@@ -393,7 +393,7 @@ public:
return 0; return 0;
} }
if (!allow_move(r, nr)) if (!allow_move(v, r, nr))
return std::numeric_limits<double>::infinity(); return std::numeric_limits<double>::infinity();
get_move_entries(v, r, nr, m_entries); get_move_entries(v, r, nr, m_entries);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment