Commit 550ab951 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: optimize multiflip_mcmc() by disabling uncessary edge list updates

parent bf0e2475
......@@ -218,7 +218,7 @@ public:
_eweight(uncheck(__aeweight, typename std::add_pointer<eweight_t>::type())),
_degs(uncheck(__adegs, typename std::add_pointer<degs_t>::type())),
_emat(other._emat),
_egroups_enabled(other._egroups_enabled),
_egroups_update(other._egroups_update),
_neighbor_sampler(other._neighbor_sampler),
_m_entries(num_vertices(_bg))
{
......@@ -313,7 +313,7 @@ public:
// move a vertex from its current block to block nr
void move_vertex(size_t v, size_t r, size_t nr)
{
move_vertex(v, r, nr, [](auto&) {return false;});
move_vertex(v, r, nr, [](auto&) constexpr {return false;});
}
void move_vertex(size_t v, size_t nr)
......@@ -444,7 +444,7 @@ public:
_wr[r] -= _vweight[v];
if (!_egroups.empty() && _egroups_enabled)
if (!_egroups.empty() && _egroups_update)
_egroups.remove_vertex(v, _b, _g);
if (is_partition_stats_enabled())
......@@ -459,7 +459,7 @@ public:
_wr[r] += _vweight[v];
if (!_egroups.empty() && _egroups_enabled)
if (!_egroups.empty() && _egroups_update)
_egroups.add_vertex(v, _b, _eweight, _g);
if (is_partition_stats_enabled())
......@@ -1356,6 +1356,7 @@ public:
template <class MEntries>
[[gnu::hot]]
double virtual_move(size_t v, size_t r, size_t nr, const entropy_args_t& ea,
MEntries& m_entries)
{
......@@ -1364,7 +1365,7 @@ public:
if (r != null_group && nr != null_group && !allow_move(r, nr))
return std::numeric_limits<double>::infinity();
get_move_entries(v, r, nr, m_entries, [](auto) { return false; });
get_move_entries(v, r, nr, m_entries, [](auto) constexpr { return false; });
if (r == nr || _vweight[v] == 0)
return 0;
......@@ -1651,6 +1652,7 @@ public:
return dS;
}
[[gnu::hot]]
double virtual_move(size_t v, size_t r, size_t nr, const entropy_args_t& ea)
{
return virtual_move(v, r, nr, ea, _m_entries);
......@@ -2625,7 +2627,7 @@ public:
emat_t _emat;
EGroups<g_t, is_weighted_t> _egroups;
bool _egroups_enabled = true;
bool _egroups_update = true;
typedef NeighborSampler<g_t, is_weighted_t, mpl::false_>
neighbor_sampler_t;
......
......@@ -215,6 +215,7 @@ public:
}
template <class Vertex, class VProp>
[[gnu::hot]]
void remove_vertex(Vertex v, VProp& b, Graph& g)
{
if (_egroups.empty())
......@@ -230,6 +231,7 @@ public:
}
template <class Vertex, class Vprop, class Eprop>
[[gnu::hot]]
void add_vertex(Vertex v, Vprop& b, Eprop& eweight, Graph& g)
{
if (_egroups.empty())
......
......@@ -95,7 +95,10 @@ struct MCMC
{
if (!_allow_vacate && _state.is_last(v))
return _null_move;
return _state.sample_block(v, _c, _d, rng);
size_t s = _state.sample_block(v, _c, _d, rng);
if (s == node_state(v))
return _null_move;
return s;
}
std::tuple<double, double>
......
......@@ -648,6 +648,8 @@ struct MCMC
if (_groups[r].size() < 2)
return {_null_move, 1};
_state._egroups_update = false;
_vs = _groups[r];
push_b(_vs);
......@@ -668,6 +670,8 @@ struct MCMC
for (auto v : _vs)
_bnext[v] = _state._b[v];
pop_b();
_state._egroups_update = true;
}
break;
......@@ -681,6 +685,8 @@ struct MCMC
if (!allow_merge(r, s))
return {_null_move, 1};
_state._egroups_update = false;
if (!std::isinf(_beta))
{
pf = merge_prob(r, s) + log(_pmerge);
......@@ -696,6 +702,8 @@ struct MCMC
_bnext[v] = _state._b[v];
pop_b();
_state._egroups_update = true;
if (_verbose)
cout << "merge proposal: " << _groups[r].size() << " "
<< _groups[s].size() << " " << _dS << " " << pb - pf
......@@ -708,6 +716,8 @@ struct MCMC
if (_rlist.size() == 1)
return {_null_move, 1};
_state._egroups_update = false;
push_b(_groups[r]);
auto ret = sample_merge(r, rng);
......@@ -717,6 +727,7 @@ struct MCMC
{
while (!_bstack.empty())
pop_b();
_state._egroups_update = true;
return {_null_move, 1};
}
......@@ -742,6 +753,8 @@ struct MCMC
while (!_bstack.empty())
pop_b();
_state._egroups_update = true;
if (_verbose)
cout << "mergesplit proposal: " << _dS << " " << pb - pf
<< " " << -_dS + pb - pf << endl;
......
......@@ -88,7 +88,7 @@ public:
_bg(boost::any_cast<std::reference_wrapper<bg_t>>(__abg)),
_c_mrs(_mrs.get_checked()),
_emat(_bg, rng),
_egroups_enabled(true),
_egroups_update(true),
_overlap_stats(_g, _b, _half_edges, _node_index, num_vertices(_bg)),
_coupled_state(nullptr)
{
......@@ -166,7 +166,7 @@ public:
_B_E_D(other._B_E_D),
_rt(other._rt),
_emat(other._emat),
_egroups_enabled(other._egroups_enabled),
_egroups_update(other._egroups_update),
_overlap_stats(other._overlap_stats),
_coupled_state(nullptr)
{
......@@ -194,13 +194,13 @@ public:
{
_overlap_stats.add_half_edge(v, r, _b, _g);
_b[v] = r;
if (!_egroups.empty() && _egroups_enabled)
if (!_egroups.empty() && _egroups_update)
_egroups.add_vertex(v, _b, _eweight, _g);
}
else
{
_overlap_stats.remove_half_edge(v, r, _b, _g);
if (!_egroups.empty() && _egroups_enabled)
if (!_egroups.empty() && _egroups_update)
_egroups.remove_vertex(v, _b, _g);
}
......@@ -1290,7 +1290,7 @@ public:
emat_t _emat;
EGroups<g_t, mpl::false_> _egroups;
bool _egroups_enabled;
bool _egroups_update;
overlap_stats_t _overlap_stats;
std::vector<overlap_partition_stats_t> _partition_stats;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment