Commit 68abf8f3 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

blockmodel: Refactor mcmc loop

parent 667a4926
......@@ -483,10 +483,8 @@ public:
// move a vertex from its current block to block nr
template <class GetB>
void move_vertex(size_t v, size_t nr, GetB&& get_b)
void move_vertex(size_t v, size_t r, size_t nr, GetB&& get_b)
{
size_t r = get_b(v);
if (r == nr)
return;
......@@ -515,6 +513,13 @@ public:
assert(size_t(get_b(v)) == nr);
}
template <class GetB>
void move_vertex(size_t v, size_t nr, GetB&& get_b)
{
size_t r = get_b(v);
move_vertex(v, r, nr, get_b);
}
void move_vertex(size_t v, size_t nr)
{
move_vertex(v, nr, [&](auto u) -> auto& { return this->_b[u]; });
......
......@@ -61,8 +61,16 @@ python::object mcmc_layered_sweep(python::object omcmc_state,
(omcmc_state,
[&](auto& s)
{
auto ret_ = mcmc_sweep(s, rng);
ret = python::make_tuple(ret_.first, ret_.second);
if (s._parallel)
{
auto ret_ = mcmc_sweep_parallel(s, rng);
ret = python::make_tuple(ret_.first, ret_.second);
}
else
{
auto ret_ = mcmc_sweep(s, rng);
ret = python::make_tuple(ret_.first, ret_.second);
}
});
},
false);
......
......@@ -61,8 +61,16 @@ python::object mcmc_layered_overlap_sweep(python::object omcmc_state,
(omcmc_state,
[&](auto& s)
{
auto ret_ = mcmc_sweep(s, rng);
ret = python::make_tuple(ret_.first, ret_.second);
if (s._parallel)
{
auto ret_ = mcmc_sweep_parallel(s, rng);
ret = python::make_tuple(ret_.first, ret_.second);
}
else
{
auto ret_ = mcmc_sweep(s, rng);
ret = python::make_tuple(ret_.first, ret_.second);
}
});
},
false);
......
......@@ -48,8 +48,16 @@ python::object do_mcmc_sweep(python::object omcmc_state,
(omcmc_state,
[&](auto& s)
{
auto ret_ = mcmc_sweep(s, rng);
ret = python::make_tuple(ret_.first, ret_.second);
if (s._parallel)
{
auto ret_ = mcmc_sweep_parallel(s, rng);
ret = python::make_tuple(ret_.first, ret_.second);
}
else
{
auto ret_ = mcmc_sweep(s, rng);
ret = python::make_tuple(ret_.first, ret_.second);
}
});
};
block_state::dispatch(oblock_state, dispatch);
......
......@@ -104,21 +104,22 @@ struct MCMC
return s;
}
std::pair<double, double> virtual_move_dS(size_t v, size_t nr)
std::tuple<double, double, double>
virtual_move_dS(size_t v, size_t nr)
{
double dS = _state.virtual_move(v, _state._b[v], nr, _entropy_args,
size_t r = _state._b[v];
double dS = _state.virtual_move(v, r, nr, _entropy_args,
_m_entries);
double a = 0;
if (!std::isinf(_c))
{
size_t r = _state._b[v];
double pf = _state.get_move_prob(v, r, nr, _c, false,
_m_entries);
double pb = _state.get_move_prob(v, nr, r, _c, true,
_m_entries);
a = log(pb) - log(pf);
}
return make_pair(dS, a);
return make_tuple(dS, a, dS);
}
void perform_move(size_t v, size_t nr)
......
......@@ -129,7 +129,8 @@ struct MCMC
return s;
}
std::pair<double, double> virtual_move_dS(size_t i, size_t nr)
std::tuple<double, double, double>
virtual_move_dS(size_t i, size_t nr)
{
double dS = 0;
......@@ -160,7 +161,7 @@ struct MCMC
for (auto v : _bundles[i])
_state.move_vertex(v, r);
return make_pair(dS, a);
return std::make_tuple(dS, a, dS);
}
void perform_move(size_t i, size_t nr)
......
......@@ -34,20 +34,86 @@
namespace graph_tool
{
template <class RNG>
bool metropolis_accept(double dS, double mP, double beta, RNG& rng)
{
if (std::isinf(beta))
{
return dS < 0;
}
else
{
double a = -dS * beta + mP;
if (a > 0)
{
return true;
}
else
{
typedef std::uniform_real_distribution<> rdist_t;
double sample = rdist_t()(rng);
return sample < exp(a);
}
}
}
template <class MCMCState, class RNG>
auto mcmc_sweep(MCMCState state, RNG& rng_)
auto mcmc_sweep(MCMCState state, RNG& rng)
{
auto& vlist = state._vlist;
auto& beta = state._beta;
double S = 0;
size_t nmoves = 0;
for (size_t iter = 0; iter < state._niter; ++iter)
{
std::shuffle(vlist.begin(), vlist.end(), rng);
for (auto v : vlist)
{
if (!state._sequential)
v = uniform_sample(vlist, rng);
if (state.node_weight(v) == 0)
continue;
auto r = state.node_state(v);
auto s = state.move_proposal(v, rng);
if (s == r)
continue;
double dS, mP, rdS;
std::tie(dS, mP, rdS) = state.virtual_move_dS(v, s);
if (metropolis_accept(dS, mP, beta, rng))
{
state.perform_move(v, s);
nmoves += state.node_weight(v);
S += rdS;
}
if (state._verbose)
cout << v << ": " << r << " -> " << s << " " << S << endl;
}
}
return make_pair(S, nmoves);
}
template <class MCMCState, class RNG>
auto mcmc_sweep_parallel(MCMCState state, RNG& rng_)
{
auto& g = state._g;
vector<std::shared_ptr<RNG>> rngs;
std::vector<std::pair<size_t, double>> best_move;
if (state._parallel)
{
init_rngs(rngs, rng_);
init_cache(state._E);
best_move.resize(num_vertices(g));
}
init_rngs(rngs, rng_);
init_cache(state._E);
best_move.resize(num_vertices(g));
auto& vlist = state._vlist;
auto& beta = state._beta;
......@@ -58,31 +124,21 @@ auto mcmc_sweep(MCMCState state, RNG& rng_)
for (size_t iter = 0; iter < state._niter; ++iter)
{
if (!state._parallel)
{
std::shuffle(vlist.begin(), vlist.end(), rng_);
}
else
{
parallel_loop(vlist,
[&](size_t, auto v)
{
best_move[v] =
std::make_pair(state.node_state(v),
numeric_limits<double>::max());
});
}
#pragma omp parallel firstprivate(state) if (state._parallel)
parallel_loop(vlist,
[&](size_t, auto v)
{
best_move[v] =
std::make_pair(state.node_state(v),
numeric_limits<double>::max());
});
#pragma omp parallel firstprivate(state)
parallel_loop_no_spawn
(vlist,
[&](size_t, auto v)
{
auto& rng = get_rng(rngs, rng_);
if (!state._sequential)
v = uniform_sample(vlist, rng);
if (state.node_weight(v) == 0)
return;
......@@ -92,69 +148,40 @@ auto mcmc_sweep(MCMCState state, RNG& rng_)
if (s == r)
return;
std::pair<double, double> dS = state.virtual_move_dS(v, s);
double dS, mP, rdS;
std::tie(dS, mP, rdS) = state.virtual_move_dS(v, s);
bool accept = false;
if (std::isinf(beta))
if (metropolis_accept(dS, mP, beta, rng))
{
accept = dS.first < 0;
}
else
{
double a = -dS.first * beta + dS.second;
if (a > 0)
{
accept = true;
}
else
{
typedef std::uniform_real_distribution<> rdist_t;
double sample = rdist_t()(rng);
accept = sample < exp(a);
}
best_move[v].first = s;
best_move[v].second = dS;
}
if (accept)
{
if (!state._parallel)
{
state.perform_move(v, s);
nmoves += state.node_weight(v);
S += dS.first;
}
else
{
best_move[v].first = s;
best_move[v].second = dS.first;
}
if (state._verbose)
cout << v << ": " << r << " -> " << s << " " << S << endl;
}
if (state._verbose)
cout << v << ": " << r << " -> " << s << " " << S << endl;
});
if (state._parallel)
for (auto v : vlist)
{
for (auto v : vlist)
auto s = best_move[v].first;
double dS = best_move[v].second;
if (dS != numeric_limits<double>::max())
{
auto s = best_move[v].first;
double dS = best_move[v].second;
if (dS != numeric_limits<double>::max())
{
dS = state.virtual_move_dS(v, s).first;
if (dS > 0 && std::isinf(beta))
continue;
state.perform_move(v, s);
nmoves++;
S += dS;
}
auto ddS = state.virtual_move_dS(v, s);
if (get<0>(ddS) > 0 && std::isinf(beta))
continue;
state.perform_move(v, s);
nmoves++;
S += get<2>(ddS);
}
}
}
return make_pair(S, nmoves);
}
} // graph_tool namespace
#endif //MCMC_LOOP_HH
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment