Commit 98d213e0 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: simplify and improve EGroups update performance

parent 18caad84
......@@ -74,7 +74,7 @@ typedef mpl::vector1<std::false_type> rmap_tr;
#define BLOCK_STATE_params \
((g, &, all_graph_views, 1)) \
((is_weighted,, mpl::vector1<std::true_type>, 1)) \
((is_weighted,, mpl::vector1<std::true_type>, 1)) \
((use_hash,, bool_tr, 1)) \
((use_rmap,, rmap_tr, 1)) \
((_abg, &, boost::any&, 0)) \
......@@ -417,9 +417,6 @@ public:
_wr[r] -= _vweight[v];
if (!_egroups.empty() && _egroups_update)
_egroups.remove_vertex(v, _b, _eweight, _g);
if (is_partition_stats_enabled())
get_partition_stats(v).remove_vertex(v, r, _deg_corr, _g,
_vweight, _eweight,
......@@ -432,9 +429,6 @@ public:
_wr[r] += _vweight[v];
if (!_egroups.empty() && _egroups_update)
_egroups.add_vertex(v, _b, _eweight, _g);
if (is_partition_stats_enabled())
get_partition_stats(v).add_vertex(v, r, _deg_corr, _g, _vweight,
_eweight, _degs);
......@@ -2353,6 +2347,14 @@ public:
void store_next_state(size_t) {}
void clear_next_state() {}
void relax_update(bool relax)
{
_egroups.check(_bg, _mrs);
_egroups_update = !relax;
if (_coupled_state != nullptr)
_coupled_state->relax_update(relax);
}
//private:
typedef typename
std::conditional<is_directed_::apply<g_t>::type::value,
......
......@@ -43,6 +43,8 @@ public:
insert_edge(source(e, bg), target(e, bg), mrs[e]);
insert_edge(target(e, bg), source(e, bg), mrs[e]);
}
check(bg, mrs);
}
void add_block()
......@@ -97,39 +99,6 @@ public:
}
}
template <bool Add, class Vertex, class Eprop, class Vprop, class Graph>
void modify_vertex(Vertex v, Vprop& b, Eprop& eweight, Graph& g)
{
auto iter_edges = [&](auto&& range)
{
for (auto e : range)
{
auto ew = (Add) ? eweight[e] : -eweight[e];
auto s = b[source(e, g)];
auto t = b[target(e, g)];
insert_edge(s, t, ew);
if (source(e, g) != target(e, g))
insert_edge(t, s, ew);
}
};
iter_edges(out_edges_range(v, g));
if constexpr (is_directed_::apply<Graph>::type::value)
iter_edges(in_edges_range(v, g));
}
template <class Vertex, class Vprop, class Eprop, class Graph>
void add_vertex(Vertex v, Vprop& b, Eprop& eweight, Graph& g)
{
modify_vertex<true>(v, b, eweight, g);
}
template <class Vertex, class Vprop, class Eprop, class Graph>
void remove_vertex(Vertex v, Vprop& b, Eprop& eweight, Graph& g)
{
modify_vertex<false>(v, b, eweight, g);
}
template <class RNG>
size_t sample_edge(size_t r, RNG& rng)
{
......@@ -138,6 +107,39 @@ public:
return s;
}
template <class Eprop, class BGraph>
void check([[maybe_unused]] BGraph& bg, [[maybe_unused]] Eprop& mrs)
{
#ifndef NDEBUG
if (empty() || true)
return;
for (auto e : edges_range(bg))
{
auto r = source(e, bg);
auto s = target(e, bg);
auto& pos = _pos[r];
auto iter = pos.find(s);
assert(iter != pos.end());
auto p = _egroups[r].get_prob(iter->second);
if (!graph_tool::is_directed(bg) || r == s)
{
assert(p == mrs[e] * (r == s ? 2 : 1));
}
else
{
auto ne = edge(s, r, bg);
if (ne.second)
assert(p == mrs[e] + mrs[ne.first]);
else
assert(p == mrs[e]);
}
}
#endif
}
private:
vector<DynamicSampler<size_t>> _egroups;
vector<gt_hash_map<size_t, size_t>> _pos;
......
......@@ -545,6 +545,8 @@ void apply_delta(State& state, MEntries& m_entries)
auto eops =
[&](auto&& eop, auto&& mid_op, auto&& end_op, auto&& skip)
{
bool update_egroups = !state._egroups.empty() && state._egroups_update;
eop(m_entries, state._emat,
[&](auto r, auto s, auto& me, auto delta, auto&... edelta)
{
......@@ -571,6 +573,19 @@ void apply_delta(State& state, MEntries& m_entries)
state._mrp[r] += delta;
state._mrm[s] += delta;
if (update_egroups)
{
if (r != s)
{
state._egroups.insert_edge(r, s, delta);
state._egroups.insert_edge(s, r, delta);
}
else
{
state._egroups.insert_edge(r, s, 2 * delta);
}
}
assert(state._mrs[me] >= 0);
assert(state._mrp[r] >= 0);
assert(state._mrm[s] >= 0);
......
......@@ -198,7 +198,7 @@ struct MCMC
void relax_update(bool relax)
{
_state._egroups_update = !relax;
_state.relax_update(relax);
}
void store_next_state(size_t v)
......@@ -208,6 +208,7 @@ struct MCMC
void clear_next_state()
{
_state._egroups.check(_state._bg, _state._mrs);
_state.clear_next_state();
}
......
......@@ -221,7 +221,7 @@ struct MCMC
void relax_update(bool relax)
{
_state._egroups_update = !relax;
_state.relax_update(relax);
}
template <class V>
......
......@@ -71,6 +71,7 @@ public:
virtual vprop_map_t<int32_t>::type::unchecked_t& get_pclabel() = 0;
virtual bool check_edge_counts(bool emat=true) = 0;
virtual bool allow_move(size_t r, size_t nr) = 0;
virtual void relax_update(bool relax) = 0;
};
} // graph_tool namespace
......
......@@ -1005,6 +1005,12 @@ struct Layers
BaseState::clear_egroups();
}
virtual void relax_update(bool relax)
{
BaseState::relax_update(relax);
}
vprop_map_t<int32_t>::type::unchecked_t& get_b()
{
return BaseState::_b;
......
......@@ -125,8 +125,6 @@ public:
typedef modularity_entropy_args_t _entropy_args_t;
bool _egroups_update = true;
// =========================================================================
// State modification
// =========================================================================
......@@ -345,7 +343,7 @@ public:
void pop_state() {}
void store_next_state(size_t) {}
void clear_next_state() {}
void relax_update(bool) {}
};
} // graph_tool namespace
......
......@@ -191,14 +191,10 @@ public:
{
_overlap_stats.add_half_edge(v, r, _b, _g);
_b[v] = r;
if (!_egroups.empty() && _egroups_update)
_egroups.add_vertex(v, _b, _eweight, _g);
}
else
{
_overlap_stats.remove_half_edge(v, r, _b, _g);
if (!_egroups.empty() && _egroups_update)
_egroups.remove_vertex(v, _b, _eweight, _g);
}
_wr[r] = _overlap_stats.get_block_size(r);
......@@ -1272,6 +1268,12 @@ public:
void store_next_state(size_t) {}
void clear_next_state() {}
void relax_update(bool relax)
{
_egroups.check(_bg, _mrs);
_egroups_update = !relax;
}
//private:
typedef typename
std::conditional<is_directed_::apply<g_t>::type::value,
......
......@@ -113,8 +113,6 @@ public:
typedef int m_entries_t;
bool _egroups_update = true;
// =========================================================================
// State modification
// =========================================================================
......@@ -298,6 +296,7 @@ public:
void pop_state() {}
void store_next_state(size_t) {}
void clear_next_state() {}
void relax_update(bool) {}
};
......
......@@ -162,8 +162,6 @@ public:
typedef int m_entries_t;
bool _egroups_update = true;
// =========================================================================
// State modification
// =========================================================================
......@@ -340,6 +338,7 @@ public:
void pop_state() {}
void store_next_state(size_t) {}
void clear_next_state() {}
void relax_update(bool) {}
};
......
......@@ -144,8 +144,6 @@ public:
typedef int m_entries_t;
bool _egroups_update = true;
typedef char _entropy_args_t;
PartitionModeState& get_mode(size_t r)
......@@ -253,6 +251,8 @@ public:
_next_list.clear();
}
void relax_update(bool) {}
size_t virtual_remove_size(size_t v)
{
return _wr[_b[v]] - 1;
......
......@@ -138,8 +138,6 @@ public:
typedef pp_entropy_args_t _entropy_args_t;
bool _egroups_update = true;
// =========================================================================
// State modification
// =========================================================================
......@@ -501,6 +499,7 @@ public:
void pop_state() {}
void store_next_state(size_t) {}
void clear_next_state() {}
void relax_update(bool) {}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment