Commit de8556fc authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: add scatter stage split to multiflip_mcmc

parent 8918329f
......@@ -175,6 +175,7 @@ struct MCMC
return false;
}
template <bool sample_branch=true>
size_t sample_new_group(size_t v, rng_t& rng)
{
_state.get_empty_block(v);
......@@ -182,8 +183,21 @@ struct MCMC
auto r = _state._b[v];
if (_state._coupled_state != nullptr)
_state._coupled_state->sample_branch(t, r, rng);
{
if constexpr (sample_branch)
{
_state._coupled_state->sample_branch(t, r, rng);
}
else
{
auto& bh = _state._coupled_state->get_b();
bh[t] = bh[r];
}
auto& hpclabel = _state._coupled_state->get_pclabel();
hpclabel[t] = _state._pclabel[v];
}
_state._bclabel[t] = _state._bclabel[r];
if (t >= _groups.size())
{
_groups.resize(t + 1);
......@@ -258,9 +272,9 @@ struct MCMC
return {dS, lp};
}
template <class RNG, bool forward=true>
template <bool forward=true, class RNG>
std::tuple<double, size_t, size_t>
stage_split(std::vector<size_t>& vs, size_t r, size_t s, RNG& rng)
stage_split_random(std::vector<size_t>& vs, size_t r, size_t s, RNG& rng)
{
std::array<size_t, 2> rt = {null_group, null_group};
std::array<double, 2> ps;
......@@ -315,6 +329,88 @@ struct MCMC
return {dS, rt[0], rt[1]};
}
template <bool forward=true, class RNG>
std::tuple<double, size_t, size_t>
stage_split_scatter(std::vector<size_t>& vs, size_t r, size_t s, RNG& rng)
{
std::array<size_t, 2> rt = {null_group, null_group};
std::array<double, 2> ps;
double dS = 0;
if (s != null_group && _groups[s].empty())
_state.move_vertex(_groups[r].front(), s);
size_t t;
if (_rlist.size() < (forward ? _N - 1 : _N))
t = sample_new_group<false>(_groups[r].front(), rng);
else
t = r;
if (s != null_group && _groups[s].empty())
_state.move_vertex(_groups[r].front(), r);
for (auto v : _groups[r])
{
dS += _state.virtual_move(v, _state._b[v], t,
_entropy_args);
move_vertex(v, t);
}
if constexpr (!forward)
{
for (auto v : _groups[s])
{
dS += _state.virtual_move(v, _state._b[v], t,
_entropy_args);
move_vertex(v, t);
}
}
std::shuffle(vs.begin(), vs.end(), rng);
for (auto v : vs)
{
if (rt[0] == null_group)
{
rt[0] = r;
dS += _state.virtual_move(v, _state._b[v], rt[0],
_entropy_args);
move_vertex(v, rt[0]);
continue;
}
if (rt[1] == null_group)
{
if constexpr (forward)
rt[1] = (s == null_group) ? sample_new_group(v, rng) : s;
else
rt[1] = s;
dS += _state.virtual_move(v, _state._b[v], rt[1],
_entropy_args);
move_vertex(v, rt[1]);
continue;
}
ps[0] = _state.virtual_move(v, _state._b[v], rt[0],
_entropy_args);
ps[1] = _state.virtual_move(v, _state._b[v], rt[1],
_entropy_args);;
double Z = log_sum(ps[0], ps[1]);
double p0 = ps[0] - Z;
std::bernoulli_distribution sample(exp(p0));
if (sample(rng))
{
dS += ps[0];
move_vertex(v, rt[0]);
}
else
{
dS += ps[1];
move_vertex(v, rt[1]);
}
}
return {dS, rt[0], rt[1]};
}
template <class RNG, bool forward=true>
std::tuple<size_t, double, double> split(size_t r, size_t s,
......@@ -327,7 +423,12 @@ struct MCMC
double dS;
std::array<size_t, 2> rt;
std::tie(dS, rt[0], rt[1]) = stage_split(vs, r, s, rng);
std::bernoulli_distribution stage_sample(.5);
if (stage_sample(rng))
std::tie(dS, rt[0], rt[1]) = stage_split_random<forward>(vs, r, s, rng);
else
std::tie(dS, rt[0], rt[1]) = stage_split_scatter<forward>(vs, r, s, rng);
for (size_t i = 0; i < _gibbs_sweeps - 1; ++i)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment