Commit 822db1f2 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

blockmodel: simplify random neighbor sampling

parent 8477c7ef
......@@ -2603,7 +2603,7 @@ public:
EGroups<g_t, is_weighted_t> _egroups;
bool _egroups_enabled = true;
typedef NeighborSampler<g_t, is_weighted_t, is_weighted_t>
typedef NeighborSampler<g_t, is_weighted_t, mpl::false_>
neighbor_sampler_t;
neighbor_sampler_t _neighbor_sampler;
......
......@@ -39,7 +39,6 @@ public:
NeighborSampler(Graph& g, Eprop& eweight, bool self_loops=false)
: _g(g),
_sampler(get(vertex_index_t(), g), num_vertices(g)),
_sampler_pos(get(edge_index_t(), g)),
_self_loops(self_loops)
{
init(eweight);
......@@ -48,7 +47,14 @@ public:
template <class Eprop>
void init_vertex(size_t v, Eprop& eweight)
{
_sampler[v].clear();
init_vertex(v, eweight, Weighted());
}
template <class Eprop>
void init_vertex(size_t v, Eprop& eweight, std::true_type)
{
std::vector<item_t> us;
std::vector<double> ps;
for (auto e : out_edges_range(v, _g))
{
......@@ -67,17 +73,17 @@ public:
w /= 2;
}
insert(v, u, w, e);
us.push_back(u);
ps.push_back(w);
}
if constexpr (is_directed_::apply<Graph>::type::value)
{
for (auto e : in_edges_range(v, _g))
{
auto u = source(e, _g);
if (!_self_loops && u == v)
if (u == v)
continue;
auto w = eweight[e];
......@@ -85,7 +91,50 @@ public:
if (w == 0)
continue;
insert(v, u, w, e);
us.push_back(u);
ps.push_back(w);
}
}
_sampler[v] = sampler_t(us, ps);
}
template <class Eprop>
void init_vertex(size_t v, Eprop&, std::false_type)
{
auto& sampler = _sampler[v];
sampler.clear();
[[maybe_unused]] gt_hash_set<size_t> sl_set;
[[maybe_unused]] auto eindex = get(edge_index_t(), _g);
for (auto e : out_edges_range(v, _g))
{
auto u = target(e, _g);
if (u == v)
{
if (!_self_loops)
continue;
if constexpr (!is_directed_::apply<Graph>::type::value)
{
if (sl_set.find(eindex[e]) != sl_set.end())
continue;
sl_set.insert(eindex[e]);
}
}
sampler.push_back(u);
}
if constexpr (is_directed_::apply<Graph>::type::value)
{
for (auto e : in_edges_range(v, _g))
{
auto u = source(e, _g);
if (u == v)
continue;
sampler.push_back(u);
}
}
}
......@@ -101,8 +150,7 @@ public:
vertex_t sample(vertex_t v, RNG& rng)
{
auto& sampler = _sampler[v];
auto& item = sample_item(sampler, rng);
return get_u(item);
return sample_item(sampler, rng);
}
bool empty(vertex_t v)
......@@ -115,54 +163,8 @@ public:
_sampler.resize(n);
}
template <class Edge>
void remove(vertex_t v, vertex_t u, Edge&& e)
{
if (v == u && !_self_loops)
return;
auto& sampler = _sampler[v];
auto& pos = _sampler_pos[e];
bool is_src = (get_src(e) == u);
remove_item({is_src, e}, sampler, pos);
}
template <class Weight, class Edge>
void insert(vertex_t v, vertex_t u, Weight w, Edge&& e)
{
if (v == u && !_self_loops)
return;
auto& sampler = _sampler[v];
auto& pos = _sampler_pos[e];
bool is_src = (get_src(e) == u);
insert_item({is_src, e}, w, sampler, pos);
}
private:
typedef std::tuple<bool, edge_t> item_t;
vertex_t get_src(const edge_t& e)
{
if constexpr (is_directed_::apply<Graph>::type::value)
return source(e, _g);
else
return std::min(source(e, _g), target(e, _g));
}
vertex_t get_tgt(const edge_t& e)
{
if constexpr (is_directed_::apply<Graph>::type::value)
return target(e, _g);
else
return std::max(source(e, _g), target(e, _g));
}
vertex_t get_u(const item_t& item)
{
if (get<0>(item))
return get_src(get<1>(item));
else
return get_tgt(get<1>(item));
}
typedef vertex_t item_t;
template <class RNG>
const item_t& sample_item(std::vector<item_t>& sampler, RNG& rng)
......@@ -176,55 +178,6 @@ private:
return sampler.sample(rng);
}
size_t& get_pos(const item_t& u, std::tuple<size_t, size_t>& pos)
{
if (get<0>(u))
return get<0>(pos);
else
return get<1>(pos);
}
void remove_item(const item_t& u, std::vector<item_t>& sampler,
std::tuple<size_t, size_t>& pos)
{
auto u_pos = get_pos(u, pos);
if (u_pos >= sampler.size() || sampler[u_pos] != u)
return;
auto& back = sampler.back();
auto& e = get<1>(back);
auto& bpos = _sampler_pos[e];
get_pos(back, bpos) = u_pos;
sampler[u_pos] = back;
sampler.pop_back();
}
template <class Sampler>
void remove_item(const item_t& u, Sampler& sampler,
std::tuple<size_t, size_t>& pos)
{
auto i = get_pos(u, pos);
if (!sampler.is_valid(i) || sampler[i] != u)
return;
sampler.remove(i);
}
template <class Weight>
void insert_item(const item_t& u, Weight, std::vector<item_t>& sampler,
std::tuple<size_t, size_t>& pos)
{
get_pos(u, pos) = sampler.size();
sampler.push_back(u);
}
template <class Weight>
void insert_item(const item_t& u, Weight w, DynamicSampler<item_t>& sampler,
std::tuple<size_t, size_t>& pos)
{
get_pos(u, pos) = sampler.insert(u, w);
}
Graph& _g;
typedef typename std::conditional<Weighted::value,
......@@ -232,15 +185,12 @@ private:
DynamicSampler<item_t>,
Sampler<item_t,
boost::mpl::false_>>::type,
vector<item_t>>::type
std::vector<item_t>>::type
sampler_t;
typedef typename vprop_map_t<sampler_t>::type::unchecked_t vsampler_t;
vsampler_t _sampler;
typedef typename eprop_map_t<std::tuple<size_t, size_t>>::type pos_map_t;
pos_map_t _sampler_pos;
bool _self_loops;
};
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment