Commit b694522e authored by Tiago Peixoto
Browse files

inference: refactor multiflip MCMC

parent 080938d8
......@@ -39,16 +39,20 @@ using namespace std;
((state, &, State&, 0)) \
((beta,, double, 0)) \
((c,, double, 0)) \
((a1,, double, 0)) \
((d,, double, 0)) \
((prec,, double, 0)) \
((psingle,, double, 0)) \
((psplit,, double, 0)) \
((pmerge,, double, 0)) \
((pmergesplit,, double, 0)) \
((nproposal, &, vector<size_t>&, 0)) \
((nacceptance, &, vector<size_t>&, 0)) \
((gibbs_sweeps,, size_t, 0)) \
((entropy_args,, entropy_args_t, 0)) \
((verbose,, int, 0)) \
((force_move,, bool, 0)) \
((niter,, size_t, 0))
enum class move_t { single_node = 0, split, merge, recombine, null };
enum class move_t { single = 0, split, merge, mergesplit, null };
template <class State>
struct MCMC
......@@ -75,13 +79,9 @@ struct MCMC
num_vertices(_state._g)),
_rpos(get(vertex_index_t(), _state._bg),
num_vertices(_state._bg)),
_btemp(get(vertex_index_t(), _state._g),
num_vertices(_state._g)),
_btemp2(get(vertex_index_t(), _state._g),
num_vertices(_state._g)),
_bprev(get(vertex_index_t(), _state._g),
num_vertices(_state._g)),
_bnext(get(vertex_index_t(), _state._g),
num_vertices(_state._g)),
_btemp(get(vertex_index_t(), _state._g),
num_vertices(_state._g))
{
_state.init_mcmc(_c,
......@@ -102,6 +102,13 @@ struct MCMC
continue;
add_element(_rlist, _rpos, r);
}
std::vector<move_t> moves
= {move_t::single, move_t::split, move_t::merge,
move_t::mergesplit};
std::vector<double> probs
= {_psingle, _psplit, _pmerge, _pmergesplit};
_move_sampler = Sampler<move_t, mpl::false_>(moves, probs);
}
typename state_t::g_t& _g;
......@@ -109,25 +116,54 @@ struct MCMC
std::vector<std::vector<size_t>> _groups;
typename vprop_map_t<size_t>::type::unchecked_t _vpos;
typename vprop_map_t<size_t>::type::unchecked_t _rpos;
size_t _nmoves = 0;
typename vprop_map_t<int>::type::unchecked_t _btemp;
typename vprop_map_t<int>::type::unchecked_t _btemp2;
typename vprop_map_t<int>::type::unchecked_t _bprev;
typename vprop_map_t<int>::type::unchecked_t _bnext;
std::vector<std::vector<std::tuple<size_t, size_t>>> _bstack;
Sampler<move_t, mpl::false_> _move_sampler;
void _push_b_dispatch() {}
template <class... Vs>
void _push_b_dispatch(const std::vector<size_t>& vs, Vs&&... vvs)
{
auto& back = _bstack.back();
for (auto v : vs)
back.emplace_back(v, _state._b[v]);
_push_b_dispatch(std::forward<Vs>(vvs)...);
}
// Opens a fresh frame on _bstack and stores in it the current block
// assignment of every vertex contained in the given vertex lists, so that a
// later pop_b() can restore them.
template <class... Vs>
void push_b(Vs&&... vvs)
{
    _bstack.push_back({});
    _push_b_dispatch(std::forward<Vs>(vvs)...);
}
void pop_b()
{
auto& back = _bstack.back();
for (auto& vb : back)
{
size_t v = get<0>(vb);
size_t s = get<1>(vb);
move_vertex(v, s);
}
_bstack.pop_back();
}
std::vector<size_t> _rlist;
std::vector<size_t> _vs;
typename vprop_map_t<int>::type::unchecked_t _bnext;
typename vprop_map_t<int>::type::unchecked_t _btemp;
constexpr static move_t _null_move = move_t::null;
size_t _N = 0;
double _dS;
double _a;
size_t _s = null_group;
size_t _t = null_group;
size_t _u = null_group;
size_t _v = null_group;
size_t node_state(size_t r)
{
......@@ -167,37 +203,79 @@ struct MCMC
remove_element(_groups[s], _vpos, v);
_state.move_vertex(v, r);
add_element(_groups[r], _vpos, v);
_nmoves++;
}
template <class RNG, bool forward=true>
std::tuple<size_t, size_t, double, double> split(size_t t, size_t r,
size_t s, RNG& rng)
template <class RNG>
std::tuple<double, double>
gibbs_sweep(std::vector<size_t>& vs, size_t r, size_t s,
double beta, RNG& rng)
{
if (forward)
_vs = _groups[t];
std::shuffle(_vs.begin(), _vs.end(), rng);
double lp = 0, dS = 0;
std::array<double,2> p = {0,0};
std::shuffle(vs.begin(), vs.end(), rng);
for (auto v : vs)
{
size_t bv = _state._b[v];
size_t nbv = (bv == r) ? s : r;
double ddS;
if (_state.virtual_remove_size(v) > 0)
ddS = _state.virtual_move(v, bv, nbv, _entropy_args);
else
ddS = std::numeric_limits<double>::infinity();
if (!std::isinf(beta) && !std::isinf(ddS))
{
double Z = log_sum(0., -ddS * beta);
p[0] = -ddS * beta - Z;
p[1] = -Z;
}
else
{
if (ddS < 0)
{
p[0] = 0;
p[1] = -std::numeric_limits<double>::infinity();
}
else
{
p[0] = -std::numeric_limits<double>::infinity();;
p[1] = 0;
}
}
std::bernoulli_distribution sample(exp(p[0]));
if (sample(rng))
{
move_vertex(v, nbv);
lp += p[0];
dS += ddS;
}
else
{
lp += p[1];
}
}
return {dS, lp};
}
template <class RNG, bool forward=true>
std::tuple<double, size_t, size_t>
stage_split(std::vector<size_t>& vs, size_t r, size_t s, RNG& rng)
{
std::array<size_t, 2> rt = {null_group, null_group};
std::array<double, 2> ps;
double lp = -log(2);
double dS = 0;
for (auto v : _vs)
std::shuffle(vs.begin(), vs.end(), rng);
for (auto v : vs)
{
if constexpr (!forward)
_btemp[v] = _state._b[v];
if (rt[0] == null_group)
{
if constexpr (forward)
rt[0] = (r == null_group) ? sample_new_group(v, rng) : r;
else
rt[0] = r;
rt[0] = r;
dS += _state.virtual_move(v, _state._b[v], rt[0],
_entropy_args);
if constexpr (forward)
move_vertex(v, rt[0]);
else
_state.move_vertex(v, rt[0]);
move_vertex(v, rt[0]);
continue;
}
......@@ -209,120 +287,86 @@ struct MCMC
rt[1] = s;
dS += _state.virtual_move(v, _state._b[v], rt[1],
_entropy_args);
if (forward)
move_vertex(v, rt[1]);
else
_state.move_vertex(v, rt[1]);
move_vertex(v, rt[1]);
continue;
}
ps[0] = -_state.virtual_move(v, _state._b[v], rt[0],
_entropy_args);
ps[1] = -_state.virtual_move(v, _state._b[v], rt[1],
_entropy_args);
double Z = 0, p0 = 0;
if (!std::isinf(_beta))
{
Z = log_sum(_beta * ps[0], _beta * ps[1]);
p0 = _beta * ps[0] - Z;
}
else
{
p0 = (ps[0] < ps[1]) ? 0 : -numeric_limits<double>::infinity();
}
ps[0] = ps[1] = 0;
double Z = log_sum(ps[0], ps[1]);
double p0 = _beta * ps[0] - Z;
std::bernoulli_distribution sample(exp(p0));
if (sample(rng))
{
if constexpr (forward)
move_vertex(v, rt[0]);
else
_state.move_vertex(v, rt[0]);
lp += p0;
dS -= ps[0];
dS += _state.virtual_move(v, _state._b[v], rt[0],
_entropy_args);
move_vertex(v, rt[0]);
}
else
{
if constexpr (forward)
move_vertex(v, rt[1]);
else
_state.move_vertex(v, rt[1]);
if (!std::isinf(_beta))
lp += _beta * ps[1] - Z;
dS -= ps[1];
dS += _state.virtual_move(v, _state._b[v], rt[1],
_entropy_args);
move_vertex(v, rt[1]);
}
}
return {dS, rt[0], rt[1]};
}
// gibbs sweep
for (size_t i = 0; i < (forward ? _gibbs_sweeps : _gibbs_sweeps - 1); ++i)
{
lp = 0;
std::array<double,2> p = {0,0};
for (auto v : _vs)
{
size_t bv = _state._b[v];
size_t nbv = (bv == rt[0]) ? rt[1] : rt[0];
double ddS;
if (_state.virtual_remove_size(v) > 0)
ddS = _state.virtual_move(v, bv, nbv, _entropy_args);
else
ddS = std::numeric_limits<double>::infinity();
if (!std::isinf(_beta) && !std::isinf(ddS))
{
double Z = log_sum(0., -ddS * _beta);
p[0] = -ddS * _beta - Z;
p[1] = -Z;
}
else
{
if (ddS < 0)
{
p[0] = 0;
p[1] = -std::numeric_limits<double>::infinity();
}
else
{
p[0] = -std::numeric_limits<double>::infinity();;
p[1] = 0;
}
}
template <class RNG, bool forward=true>
std::tuple<size_t, double, double> split(size_t r, size_t s,
RNG& rng)
{
auto vs = _groups[r];
std::bernoulli_distribution sample(exp(p[0]));
if (sample(rng))
{
if constexpr (forward)
move_vertex(v, nbv);
else
_state.move_vertex(v, nbv);
lp += p[0];
dS += ddS;
}
else
{
lp += p[1];
}
}
if constexpr (!forward)
vs.insert(vs.end(), _groups[s].begin(), _groups[s].end());
double dS;
std::array<size_t, 2> rt;
std::tie(dS, rt[0], rt[1]) = stage_split(vs, r, s, rng);
for (size_t i = 0; i < _gibbs_sweeps - 1; ++i)
{
auto ret = gibbs_sweep(vs, rt[0], rt[1],
(i < _gibbs_sweeps / 2) ? 1 : _beta,
rng);
dS += get<0>(ret);
}
return {rt[0], rt[1], dS, lp};
double lp = 0;
if constexpr (forward)
{
auto ret = gibbs_sweep(vs, rt[0], rt[1], _beta, rng);
dS += get<0>(ret);
lp = get<1>(ret);
}
return {rt[1], dS, lp};
}
template <class RNG>
double split_prob(size_t t, size_t r, size_t s, RNG& rng)
double split_prob(size_t r, size_t s, RNG& rng)
{
split<RNG, false>(t, r, s, rng);
auto vs = _groups[r];
vs.insert(vs.end(), _groups[s].begin(), _groups[s].end());
for (auto v : _vs)
_btemp2[v] = _state._b[v];
push_b(vs);
for (auto v : vs)
_btemp[v] = _state._b[v];
split<RNG, false>(r, s, rng);
std::shuffle(vs.begin(), vs.end(), rng);
double lp1 = 0, lp2 = 0;
for (bool swap : std::array<bool,2>({false, true}))
{
if (swap)
std::swap(r, s);
for (auto v : _vs)
if (!swap)
push_b(vs);
for (auto v : vs)
{
size_t bv = _state._b[v];
size_t nbv = (bv == r) ? s : r;
......@@ -337,10 +381,11 @@ struct MCMC
double Z = log_sum(0., -ddS);
size_t tbv = _btemp[v];
double p;
if ((size_t(_bprev[v]) == r) == (nbv == r))
if ((swap) ? tbv != nbv : tbv == nbv)
{
_state.move_vertex(v, nbv);
move_vertex(v, nbv);
p = -ddS - Z;
}
else
......@@ -355,50 +400,45 @@ struct MCMC
}
if (!swap)
{
for (auto v : _vs)
_state.move_vertex(v, _btemp2[v]);
}
pop_b();
}
for (auto v : _vs)
_state.move_vertex(v, _btemp[v]);
pop_b();
return log_sum(lp1, lp2) - log(2);;
return log_sum(lp1, lp2) - log(2);
}
// Tells whether groups r and s may be merged. The decision is delegated
// entirely to the state's allow_move(r, s).
// (A previously disabled variant compared the coupled state's labels of r and
// s, or their _bclabel entries, directly; it remains unused.)
bool allow_merge(size_t r, size_t s)
{
    const bool permitted = _state.allow_move(r, s);
    return permitted;
}
template <class RNG>
std::tuple<size_t, double> merge(size_t r, size_t s, size_t t, RNG& rng)
double merge(size_t r, size_t s)
{
double dS = 0;
_vs = _groups[r];
_vs.insert(_vs.end(), _groups[s].begin(), _groups[s].end());
if (t == null_group)
t = sample_new_group(_vs.front(), rng);
auto vs = _groups[r];
for (auto v : _vs)
for (auto v : vs)
{
size_t bv = _bprev[v] = _state._b[v];
dS +=_state.virtual_move(v, bv, t, _entropy_args);
move_vertex(v, t);
size_t bv = _state._b[v];
dS +=_state.virtual_move(v, bv, s, _entropy_args);
move_vertex(v, s);
}
return {t, dS};
return dS;
}
// Picks a group different from r: repeatedly draws a vertex uniformly from
// _groups[r] and asks the state to sample a block for it (with parameters _c
// and 0), until the sampled block is not r itself.
// NOTE(review): the loop only ends once sample_block yields something other
// than r -- confirm this is always reachable for the callers' inputs.
template <class RNG>
size_t sample_move(size_t r, RNG& rng)
{
    size_t candidate;
    do
    {
        size_t vertex = uniform_sample(_groups[r], rng);
        candidate = _state.sample_block(vertex, _c, 0, rng);
    }
    while (candidate == r);
    return candidate;
}
double get_move_prob(size_t r, size_t s)
......@@ -416,158 +456,222 @@ struct MCMC
double merge_prob(size_t r, size_t s)
{
double pr = get_move_prob(r, s);
double ps = get_move_prob(s, r);
return log(pr + ps) - log(2);
return pr;
return log(get_move_prob(r, s));
}
// Proposes and performs a merge of group r into a partner group obtained from
// sample_move(r, rng).
//
// Returns (partner group, dS, pf, pb):
//   - if allow_merge() rejects the pair, {null_group, 0, 0, 0} and nothing is
//     moved;
//   - otherwise dS is the value returned by merge(r, partner), pf comes from
//     merge_prob(r, partner) and pb from split_prob(partner, r, rng). Both
//     probabilities stay 0 when _beta is infinite.
//
// The probabilities are evaluated before merge() is called, i.e. on the
// pre-merge partition; the verbose group sizes are printed at that point too.
template <class RNG>
std::tuple<size_t, double, double, double>
sample_merge(size_t r, RNG& rng)
{
    const size_t partner = sample_move(r, rng);

    if (!allow_merge(r, partner))
        return {null_group, 0., 0., 0.};

    double fwd = 0;
    double bwd = 0;
    if (!std::isinf(_beta))
    {
        fwd = merge_prob(r, partner);
        bwd = split_prob(partner, r, rng);
    }

    if (_verbose)
        cout << "merge " << _groups[r].size() << " " << _groups[partner].size();

    const double delta = merge(r, partner);

    if (_verbose)
        cout << " " << delta << " " << fwd << " " << bwd << endl;

    return {partner, delta, fwd, bwd};
}
// Performs a split via split(r, s, rng) and gathers the quantities needed by
// the caller.
//
// Returns (s, dS, pf, pb), where s, dS and pf are the three values produced by
// split() (s is overwritten with its first element), and pb is
// merge_prob(s, r) evaluated after the split -- or 0 when _beta is infinite.
template <class RNG>
std::tuple<size_t, double, double, double>
sample_split(size_t r, size_t s, RNG& rng)
{
    const auto outcome = split(r, s, rng);
    s = std::get<0>(outcome);
    const double delta = std::get<1>(outcome);
    const double fwd = std::get<2>(outcome);

    double bwd = 0;
    if (!std::isinf(_beta))
        bwd = merge_prob(s, r);

    if (_verbose)
    {
        cout << "split " << _groups[r].size() << " " << _groups[s].size()
             << " " << delta << " " << fwd << " " << bwd << endl;
    }

    return {s, delta, fwd, bwd};
}
template <class RNG>
std::tuple<move_t,size_t> move_proposal(size_t r, RNG& rng)
{
move_t move;
double pf = 0, pb = 0;
_dS = _a = 0;
_vs.clear();
_nmoves = 0;
std::bernoulli_distribution single(_a1);
if (single(rng))
{
move = move_t::single_node;
auto v = uniform_sample(_groups[r], rng);
_s = _state.sample_block(v, _c, _d, rng);
if (_s >= _groups.size())
{
_groups.resize(_s + 1);
_rpos.resize(_s + 1);
}
if (r == _s || !_state.allow_move(r, _s))
return {_null_move, 1};
if (_d == 0 && _groups[r].size() == 1 && !std::isinf(_beta))
return {_null_move, 1};
_dS = _state.virtual_move(v, r, _s, _entropy_args);
if (!std::isinf(_beta))
{
pf = log(_state.get_move_prob(v, r, _s, _c, _d, false));
pb = log(_state.get_move_prob(v, _s, r, _c, _d, true));
pf += -safelog_fast(_rlist.size());
pf += -safelog_fast(_groups[r].size());
int dB = 0;
if (_groups[_s].empty())
dB++;
if (_groups[r].size() == 1)
dB--;
pb += -safelog_fast(_rlist.size() + dB);
pb += -safelog_fast(_groups[_s].size() + 1);
}
_vs.clear();
_vs.push_back(v);
_bprev[v] = r;
_bnext[v] = _s;
}