Commit 08dfc58c authored by Tiago Peixoto's avatar Tiago Peixoto

inference: implement coealescence split stage in multiflip_mcmc_sweep()

parent 4eee3d0d
......@@ -1701,9 +1701,9 @@ public:
// Move proposals
// =========================================================================
size_t get_empty_block(size_t v)
size_t get_empty_block(size_t v, bool force_add = false)
{
if (_empty_blocks.empty())
if (_empty_blocks.empty() || force_add)
{
add_block();
auto s = _empty_blocks.back();
......
......@@ -177,11 +177,16 @@ struct MCMC
return false;
}
template <bool sample_branch=true>
size_t sample_new_group(size_t v, rng_t& rng)
template <bool sample_branch=true, class VS = std::array<size_t,0>>
size_t sample_new_group(size_t v, rng_t& rng, VS&& except = VS())
{
_state.get_empty_block(v);
auto t = uniform_sample(_state._empty_blocks, rng);
_state.get_empty_block(v, except.size() >= _state._empty_blocks.size());
size_t t;
do
{
t = uniform_sample(_state._empty_blocks, rng);
} while (!except.empty() &&
std::find(except.begin(), except.end(), t) != except.end());
auto r = _state._b[v];
_state._bclabel[t] = _state._bclabel[r];
......@@ -337,32 +342,112 @@ struct MCMC
std::array<double, 2> ps;
double dS = 0;
if (s != null_group && _groups[s].empty())
_state.move_vertex(_groups[r].front(), s);
std::array<size_t, 2> except = {r, s};
size_t t;
if (_rlist.size() < (forward ? _N - 1 : _N))
t = sample_new_group<false>(_groups[r].front(), rng);
t = sample_new_group<false>(_groups[r].front(), rng, except);
else
t = r;
if (s != null_group && _groups[s].empty())
_state.move_vertex(_groups[r].front(), r);
for (auto v : _groups[r])
{
dS += _state.virtual_move(v, _state._b[v], t,
_entropy_args);
move_vertex(v, t);
}
if constexpr (!forward)
{
for (auto v : _groups[s])
{
dS += _state.virtual_move(v, _state._b[v], t,
_entropy_args);
move_vertex(v, t);
}
}
std::shuffle(vs.begin(), vs.end(), rng);
for (auto v : vs)
{
if (rt[0] == null_group)
{
rt[0] = r;
dS += _state.virtual_move(v, _state._b[v], rt[0],
_entropy_args);
move_vertex(v, rt[0]);
continue;
}
if (rt[1] == null_group)
{
if constexpr (forward)
rt[1] = (s == null_group) ? sample_new_group(v, rng) : s;
else
rt[1] = s;
dS += _state.virtual_move(v, _state._b[v], rt[1],
_entropy_args);
move_vertex(v, rt[1]);
continue;
}
ps[0] = _state.virtual_move(v, _state._b[v], rt[0],
_entropy_args);
ps[1] = _state.virtual_move(v, _state._b[v], rt[1],
_entropy_args);;
double Z = log_sum(ps[0], ps[1]);
double p0 = ps[0] - Z;
std::bernoulli_distribution sample(exp(p0));
if (sample(rng))
{
dS += ps[0];
move_vertex(v, rt[0]);
}
else
{
dS += ps[1];
move_vertex(v, rt[1]);
}
}
return {dS, rt[0], rt[1]};
}
template <bool forward=true, class RNG>
std::tuple<double, size_t, size_t>
stage_split_coalesce(std::vector<size_t>& vs, size_t r, size_t s, RNG& rng)
{
std::array<size_t, 2> rt = {null_group, null_group};
std::array<double, 2> ps;
double dS = 0;
size_t pos = 0;
std::array<size_t, 2> except = {r, s};
for (auto v : _groups[r])
{
size_t t;
if (_rlist.size() + pos < (forward ? _N - 1 : _N))
t = sample_new_group<false>(v, rng, except);
else
t = r;
dS += _state.virtual_move(v, _state._b[v], t,
_entropy_args);
move_vertex(v, t);
++pos;
}
if constexpr (!forward)
{
for (auto v : _groups[s])
{
size_t t;
if (_rlist.size() + pos < (forward ? _N - 1 : _N))
t = sample_new_group<false>(v, rng, except);
else
t = s;
dS += _state.virtual_move(v, _state._b[v], t,
_entropy_args);
move_vertex(v, t);
++pos;
}
}
......@@ -420,14 +505,24 @@ struct MCMC
if constexpr (!forward)
vs.insert(vs.end(), _groups[s].begin(), _groups[s].end());
double dS;
std::array<size_t, 2> rt;
double dS = 0;
std::array<size_t, 2> rt = {null_group, null_group};
std::bernoulli_distribution stage_sample(.5);
if (stage_sample(rng))
std::uniform_int_distribution<int> stage_sample(0,2);
switch (stage_sample(rng))
{
case 0:
std::tie(dS, rt[0], rt[1]) = stage_split_random<forward>(vs, r, s, rng);
else
break;
case 1:
std::tie(dS, rt[0], rt[1]) = stage_split_scatter<forward>(vs, r, s, rng);
break;
case 2:
std::tie(dS, rt[0], rt[1]) = stage_split_coalesce<forward>(vs, r, s, rng);
break;
default:
break;
}
for (size_t i = 0; i < _gibbs_sweeps - 1; ++i)
{
......
......@@ -548,9 +548,9 @@ public:
return dS;
}
size_t get_empty_block(size_t v)
size_t get_empty_block(size_t v, bool force_add = true)
{
if (_empty_blocks.empty())
if (_empty_blocks.empty() || force_add)
{
add_block();
auto s = _empty_blocks.back();
......
......@@ -193,7 +193,7 @@ public:
return (Sa - Sb);
}
size_t get_empty_block(size_t)
size_t get_empty_block(size_t, bool)
{
return _empty_blocks.back();
}
......
......@@ -281,7 +281,7 @@ public:
return dS;
}
size_t get_empty_block(size_t)
size_t get_empty_block(size_t, bool)
{
return _empty_blocks.back();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment