Commit da0ea999 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: modularize neighbour sampling code

parent e2e46593
......@@ -120,7 +120,7 @@ public:
_free.push_back(pos);
}
void reset()
void clear()
{
_items.clear();
_tree.clear();
......@@ -142,7 +142,7 @@ public:
probs.push_back(_tree[i]);
}
reset();
clear();
for (size_t i = 0; i < items.size(); ++i)
insert(items[i], probs[i]);
......@@ -153,11 +153,21 @@ public:
return _items[i];
}
const auto& items() const
{
return _items;
}
size_t size() const
{
return _items.size();
}
bool empty() const
{
return _items.empty();
}
private:
void check_size(size_t i)
......
......@@ -147,6 +147,6 @@ BOOST_PYTHON_MODULE(libgraph_tool_generation)
return_value_policy<copy_const_reference>())
.def("insert", &DynamicSampler<int>::insert)
.def("remove", &DynamicSampler<int>::remove)
.def("reset", &DynamicSampler<int>::reset)
.def("clear", &DynamicSampler<int>::clear)
.def("rebuild", &DynamicSampler<int>::rebuild);
}
......@@ -68,6 +68,7 @@ libgraph_tool_inference_la_include_HEADERS = \
graph_blockmodel_overlap_util.hh \
graph_blockmodel_overlap_vacate.hh \
graph_blockmodel_util.hh \
graph_neighbour_sampler.hh \
graph_state.hh \
mcmc_loop.hh \
merge_loop.hh \
......
......@@ -111,12 +111,11 @@ public:
_vweight(uncheck(__avweight, typename std::add_pointer<vweight_t>::type())),
_eweight(uncheck(__aeweight, typename std::add_pointer<eweight_t>::type())),
_emat(_bg, rng),
_neighbour_sampler(get(vertex_index_t(), _g), num_vertices(_g)),
_neighbour_sampler(_g, _eweight),
_m_entries(num_vertices(_bg)),
_coupled_state(nullptr),
_gstate(this)
{
rebuild_neighbour_sampler();
_empty_blocks.clear();
_candidate_blocks.clear();
_candidate_blocks.push_back(null_group);
......@@ -1125,10 +1124,9 @@ public:
s = uniform_sample(_empty_blocks, rng);
}
auto& sampler = _neighbour_sampler[v];
if (!std::isinf(c) && !sampler.empty())
if (!std::isinf(c) && !_neighbour_sampler.empty(v))
{
auto u = sample_neighbour(sampler, rng);
auto u = _neighbour_sampler.sample(v, rng);
size_t t = _b[u];
double p_rand = 0;
if (c > 0)
......@@ -1163,10 +1161,9 @@ public:
size_t random_neighbour(size_t v, rng_t& rng)
{
auto& sampler = _neighbour_sampler[v];
if (sampler.empty())
if (_neighbour_sampler.empty(v))
return v;
return sample_neighbour(sampler, rng);
return _neighbour_sampler.sample(v, rng);
}
// Computes the move proposal probability
......@@ -1571,7 +1568,7 @@ public:
void rebuild_neighbour_sampler()
{
init_neighbour_sampler(_g, _eweight, _neighbour_sampler);
_neighbour_sampler = neighbour_sampler_t(_g, _eweight);
}
void sync_emat()
......@@ -1608,13 +1605,11 @@ public:
EGroups<g_t, is_weighted_t> _egroups;
typedef typename std::conditional<is_weighted_t::value,
typename property_map_type::apply<Sampler<size_t, mpl::false_>,
typename property_map<g_t, vertex_index_t>::type>::type,
typename property_map_type::apply<vector<size_t>,
typename property_map<g_t, vertex_index_t>::type>::type>::type::unchecked_t
sampler_map_t;
WeightedNeighbourSampler<g_t, DynamicSampler>,
UnweightedNeighbourSampler<g_t>>::type
neighbour_sampler_t;
sampler_map_t _neighbour_sampler;
neighbour_sampler_t _neighbour_sampler;
std::vector<partition_stats_t> _partition_stats;
std::vector<size_t> _bmap;
......
......@@ -26,6 +26,7 @@
#include "../generation/sampler.hh"
#include "../generation/dynamic_sampler.hh"
#include "graph_neighbour_sampler.hh"
#include "util.hh"
#include "int_part.hh"
......@@ -1459,73 +1460,6 @@ private:
epos_t _epos;
};
// Sample neighbours efficiently
// =============================
template <class Vertex, class Graph, class Eprop, class SMap>
void build_neighbour_sampler(Vertex v, SMap& sampler, Eprop& eweight, Graph& g,
bool self_loops=false)
{
vector<Vertex> neighbours;
vector<double> probs;
neighbours.reserve(total_degreeS()(v, g));
probs.reserve(total_degreeS()(v, g));
for (auto e : all_edges_range(v, g))
{
Vertex u = target(e, g);
if (is_directed::apply<Graph>::type::value && u == v)
u = source(e, g);
if (!self_loops && u == v)
continue;
auto w = eweight[e];
if (w == 0)
continue;
neighbours.push_back(u);
probs.push_back(w); // Self-loops will be added twice, and hence will
// be sampled with probability 2 * eweight[e]
}
sampler = Sampler<Vertex, mpl::false_>(neighbours, probs);
};
template <class Vertex, class Graph, class Eprop>
void build_neighbour_sampler(Vertex v, vector<size_t>& sampler, Eprop&, Graph& g,
bool self_loops=false)
{
sampler.clear();
for (auto e : all_edges_range(v, g))
{
Vertex u = target(e, g);
if (is_directed::apply<Graph>::type::value && u == v)
u = source(e, g);
if (!self_loops && u == v)
continue;
sampler.push_back(u); // Self-loops will be added twice
}
};
template <class Graph, class Eprop, class Sampler>
void init_neighbour_sampler(Graph& g, Eprop eweight, Sampler& sampler)
{
for (auto v : vertices_range(g))
build_neighbour_sampler(v, sampler[v], eweight, g);
}
template <class Sampler, class RNG>
auto sample_neighbour(Sampler& sampler, RNG& rng)
{
return sampler.sample(rng);
}
template <class Vertex, class RNG>
auto sample_neighbour(vector<Vertex>& sampler, RNG& rng)
{
return uniform_sample(sampler, rng);
}
// Sampling marginal probabilities on the edges
template <class Graph, class Vprop, class MEprop>
void collect_edge_marginals(size_t B, Vprop b, MEprop p, Graph& g, Graph&)
......
// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2006-2016 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
#ifndef GRAPH_NEIGHBOUR_SAMPLER_HH
#define GRAPH_NEIGHBOUR_SAMPLER_HH
#include "config.h"
#include "graph_tool.hh"
// Sample neighbours efficiently
// =============================
namespace graph_tool
{
template <class Graph>
class UnweightedNeighbourSampler
{
public:
typedef typename boost::graph_traits<Graph>::vertex_descriptor vertex_t;
template <class Eprop>
UnweightedNeighbourSampler(Graph& g, Eprop&, bool self_loops=false)
: _sampler_c(get(vertex_index_t(), g)),
_sampler(_sampler_c.get_unchecked()),
_sampler_pos_c(get(vertex_index_t(), g)),
_sampler_pos(_sampler_pos_c.get_unchecked())
{
for (auto v : vertices_range(g))
{
auto& sampler = _sampler_c[v];
auto& sampler_pos = _sampler_pos_c[v];
sampler.clear();
sampler_pos.clear();
for (auto e : all_edges_range(v, g))
{
auto u = target(e, g);
if (is_directed::apply<Graph>::type::value && u == v)
u = source(e, g);
if (!self_loops && u == v)
continue;
sampler_pos[u].push_back(sampler.size());
sampler.push_back(u); // Self-loops will be added twice
}
}
}
template <class RNG>
vertex_t sample(vertex_t v, RNG& rng)
{
return uniform_sample(_sampler[v], rng);
}
bool empty(vertex_t v)
{
return _sampler[v].empty();
}
template <class Weight>
void remove(vertex_t v, vertex_t u, Weight)
{
auto& sampler = _sampler[v];
auto& sampler_pos = _sampler_pos[v];
auto& pos_u = sampler_pos[u];
size_t i = pos_u.back();
sampler[i] = sampler.back();
for (auto& j : sampler_pos[sampler.back()])
{
if (j == sampler.size() - 1)
{
j = i;
break;
}
}
sampler.pop_back();
pos_u.pop_back();
if (pos_u.empty())
sampler_pos.erase(u);
}
template <class Weight>
void insert(vertex_t v, vertex_t u, Weight)
{
auto& sampler = _sampler[v];
auto& sampler_pos = _sampler_pos[v];
sampler_pos[u].push_back(sampler.size());
sampler.push_back(u);
}
private:
typedef typename vprop_map_t<std::vector<vertex_t>>::type sampler_t;
sampler_t _sampler_c;
typename sampler_t::unchecked_t _sampler;
typedef typename vprop_map_t<gt_hash_map<vertex_t, vector<size_t>>>::type sampler_pos_t;
sampler_pos_t _sampler_pos_c;
typename sampler_pos_t::unchecked_t _sampler_pos;
};
template <class Graph, template <class V> class Sampler>
class WeightedNeighbourSampler
{
public:
typedef typename boost::graph_traits<Graph>::vertex_descriptor vertex_t;
template <class Eprop>
WeightedNeighbourSampler(Graph& g, Eprop& eweight, bool self_loops=false)
: _sampler_c(get(vertex_index_t(), g)),
_sampler(_sampler_c.get_unchecked()),
_sampler_pos_c(get(vertex_index_t(), g)),
_sampler_pos(_sampler_pos_c.get_unchecked())
{
for (auto v : vertices_range(g))
{
auto& sampler = _sampler_c[v];
auto& sampler_pos = _sampler_pos_c[v];
sampler.clear();
sampler_pos.clear();
for (auto e : all_edges_range(v, g))
{
auto u = target(e, g);
if (is_directed::apply<Graph>::type::value && u == v)
u = source(e, g);
if (!self_loops && u == v)
continue;
auto w = eweight[e];
if (w == 0)
continue;
insert(v, u, w); // Self-loops will be added twice, and hence will
// be sampled with probability 2 * eweight[e]
}
}
}
template <class RNG>
vertex_t sample(vertex_t v, RNG& rng)
{
return _sampler[v].sample(rng);
}
bool empty(vertex_t v)
{
return _sampler[v].empty();
}
template <class Weight>
void remove(vertex_t v, vertex_t u, Weight w)
{
auto& sampler = _sampler[v];
auto& sampler_pos = _sampler_pos[v];
auto& pos = sampler_pos[u];
sampler.remove(pos.first);
w -= pos.second;
if (w > 0)
insert(v, u, w);
else
sampler_pos.erase(u);
}
template <class Weight>
void insert(vertex_t v, vertex_t u, Weight w)
{
auto& sampler = _sampler[v];
auto& sampler_pos = _sampler_pos[v];
auto pos = sampler_pos.find(u);
if (pos != sampler_pos.end())
{
auto old_w = pos->second.second;
remove(v, u, old_w);
w += old_w;
}
sampler_pos[u] = std::make_pair(sampler.insert(u, w), w);
}
private:
typedef typename vprop_map_t<Sampler<vertex_t>>::type sampler_t;
sampler_t _sampler_c;
typename sampler_t::unchecked_t _sampler;
typedef typename vprop_map_t<gt_hash_map<vertex_t, pair<size_t, double>>>::type sampler_pos_t;
sampler_pos_t _sampler_pos_c;
typename sampler_pos_t::unchecked_t _sampler_pos;
};
}
#endif // GRAPH_NEIGHBOUR_SAMPLER_HH
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment