Commit f4234551 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

Unify all random number generators into a single module-wise instance

The object may be globally reset via the seed_rng() function.
parent 2e0b08d3
...@@ -19,20 +19,14 @@ ...@@ -19,20 +19,14 @@
#include "graph.hh" #include "graph.hh"
#include "graph_properties.hh" #include "graph_properties.hh"
#include <boost/graph/random_spanning_tree.hpp> #include "random.hh"
#if (GCC_VERSION >= 40400) #include <boost/graph/random_spanning_tree.hpp>
# include <tr1/random>
#else
# include <boost/tr1/random.hpp>
#endif
using namespace std; using namespace std;
using namespace boost; using namespace boost;
using namespace graph_tool; using namespace graph_tool;
typedef tr1::mt19937 rng_t;
struct get_random_span_tree struct get_random_span_tree
{ {
template <class Graph, class IndexMap, class WeightMap, class TreeMap, template <class Graph, class IndexMap, class WeightMap, class TreeMap,
...@@ -89,10 +83,8 @@ typedef property_map_types::apply<mpl::vector<uint8_t>, ...@@ -89,10 +83,8 @@ typedef property_map_types::apply<mpl::vector<uint8_t>,
void get_random_spanning_tree(GraphInterface& gi, size_t root, void get_random_spanning_tree(GraphInterface& gi, size_t root,
boost::any weight_map, boost::any tree_map, boost::any weight_map, boost::any tree_map,
size_t seed) rng_t& rng)
{ {
rng_t rng(static_cast<rng_t::result_type>(seed));
typedef ConstantPropertyMap<size_t,GraphInterface::edge_t> cweight_t; typedef ConstantPropertyMap<size_t,GraphInterface::edge_t> cweight_t;
if (weight_map.empty()) if (weight_map.empty())
......
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include "graph.hh" #include "graph.hh"
#include "graph_filtering.hh" #include "graph_filtering.hh"
#include "random.hh"
#include <graph_subgraph_isomorphism.hh> #include <graph_subgraph_isomorphism.hh>
#include <graph_python_interface.hh> #include <graph_python_interface.hh>
...@@ -65,16 +67,15 @@ struct get_subgraphs ...@@ -65,16 +67,15 @@ struct get_subgraphs
VertexLabel vertex_label1, boost::any vertex_label2, VertexLabel vertex_label1, boost::any vertex_label2,
EdgeLabel edge_label1, boost::any edge_label2, EdgeLabel edge_label1, boost::any edge_label2,
vector<vector<pair<size_t, size_t> > >& F, vector<vector<pair<size_t, size_t> > >& F,
vector<size_t>& vlist, pair<size_t,size_t> sn) const vector<size_t>& vlist, pair<reference_wrapper<rng_t>,size_t> sn) const
{ {
typedef PropLabelling<Graph1,Graph2,VertexLabel,VertexLabel> typedef PropLabelling<Graph1,Graph2,VertexLabel,VertexLabel>
vlabelling_t; vlabelling_t;
typedef PropLabelling<Graph1,Graph2,EdgeLabel,EdgeLabel> typedef PropLabelling<Graph1,Graph2,EdgeLabel,EdgeLabel>
elabelling_t; elabelling_t;
size_t seed = sn.first; rng_t& rng = sn.first;
size_t max_n = sn.second; size_t max_n = sn.second;
rng_t rng(static_cast<rng_t::result_type>(seed));
vlist.resize(num_vertices(*g)); vlist.resize(num_vertices(*g));
int i, N = num_vertices(*g); int i, N = num_vertices(*g);
for (i = 0; i < N; ++i) for (i = 0; i < N; ++i)
...@@ -166,7 +167,7 @@ void subgraph_isomorphism(GraphInterface& gi1, GraphInterface& gi2, ...@@ -166,7 +167,7 @@ void subgraph_isomorphism(GraphInterface& gi1, GraphInterface& gi2,
boost::any vertex_label1, boost::any vertex_label2, boost::any vertex_label1, boost::any vertex_label2,
boost::any edge_label1, boost::any edge_label2, boost::any edge_label1, boost::any edge_label2,
python::list vmapping, python::list emapping, python::list vmapping, python::list emapping,
size_t max_n, size_t seed) size_t max_n, rng_t& rng)
{ {
if (gi1.GetDirected() != gi2.GetDirected()) if (gi1.GetDirected() != gi2.GetDirected())
return; return;
...@@ -191,7 +192,7 @@ void subgraph_isomorphism(GraphInterface& gi1, GraphInterface& gi2, ...@@ -191,7 +192,7 @@ void subgraph_isomorphism(GraphInterface& gi1, GraphInterface& gi2,
run_action<graph_tool::detail::always_directed>() run_action<graph_tool::detail::always_directed>()
(gi1, bind<void>(get_subgraphs(), (gi1, bind<void>(get_subgraphs(),
_1, _2, _3, vertex_label2, _4, edge_label2, _1, _2, _3, vertex_label2, _4, edge_label2,
ref(F), ref(vlist), make_pair(seed, max_n)), ref(F), ref(vlist), make_pair(ref(rng), max_n)),
directed_graph_view_pointers(), vertex_props_t(), directed_graph_view_pointers(), vertex_props_t(),
edge_props_t()) edge_props_t())
(gi2.GetGraphView(), vertex_label1, edge_label1); (gi2.GetGraphView(), vertex_label1, edge_label1);
...@@ -201,7 +202,7 @@ void subgraph_isomorphism(GraphInterface& gi1, GraphInterface& gi2, ...@@ -201,7 +202,7 @@ void subgraph_isomorphism(GraphInterface& gi1, GraphInterface& gi2,
run_action<graph_tool::detail::never_directed>() run_action<graph_tool::detail::never_directed>()
(gi1, bind<void>(get_subgraphs(), (gi1, bind<void>(get_subgraphs(),
_1, _2, _3, vertex_label2, _4, edge_label2, _1, _2, _3, vertex_label2, _4, edge_label2,
ref(F), ref(vlist), make_pair(seed, max_n)), ref(F), ref(vlist), make_pair(ref(rng), max_n)),
undirected_graph_view_pointers(), vertex_props_t(), undirected_graph_view_pointers(), vertex_props_t(),
edge_props_t()) edge_props_t())
(gi2.GetGraphView(), vertex_label1, edge_label1); (gi2.GetGraphView(), vertex_label1, edge_label1);
......
...@@ -22,17 +22,14 @@ ...@@ -22,17 +22,14 @@
#include <utility> #include <utility>
#if (GCC_VERSION >= 40400) #if (GCC_VERSION >= 40400)
# include <tr1/unordered_set> # include <tr1/unordered_set>
# include <tr1/random>
#else #else
# include <boost/tr1/unordered_set.hpp> # include <boost/tr1/unordered_set.hpp>
# include <boost/tr1/random.hpp>
#endif #endif
namespace boost namespace boost
{ {
using namespace std; using namespace std;
typedef tr1::mt19937 rng_t;
namespace detail { namespace detail {
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <boost/python.hpp> #include <boost/python.hpp>
#include "graph.hh" #include "graph.hh"
#include "random.hh"
using namespace boost; using namespace boost;
using namespace boost::python; using namespace boost::python;
...@@ -37,14 +38,14 @@ void subgraph_isomorphism(GraphInterface& gi1, GraphInterface& gi2, ...@@ -37,14 +38,14 @@ void subgraph_isomorphism(GraphInterface& gi1, GraphInterface& gi2,
boost::any vertex_label1, boost::any vertex_label2, boost::any vertex_label1, boost::any vertex_label2,
boost::any edge_label1, boost::any edge_label2, boost::any edge_label1, boost::any edge_label2,
python::list vmapping, python::list emapping, python::list vmapping, python::list emapping,
size_t n_max, size_t seed); size_t n_max, rng_t& rng);
double reciprocity(GraphInterface& gi); double reciprocity(GraphInterface& gi);
size_t sequential_coloring(GraphInterface& gi, boost::any order, size_t sequential_coloring(GraphInterface& gi, boost::any order,
boost::any color); boost::any color);
bool is_bipartite(GraphInterface& gi, boost::any part_map); bool is_bipartite(GraphInterface& gi, boost::any part_map);
void get_random_spanning_tree(GraphInterface& gi, size_t root, void get_random_spanning_tree(GraphInterface& gi, size_t root,
boost::any weight_map, boost::any tree_map, boost::any weight_map, boost::any tree_map,
size_t seed); rng_t& rng);
vector<int32_t> get_tsp(GraphInterface& gi, size_t src, boost::any weight_map); vector<int32_t> get_tsp(GraphInterface& gi, size_t src, boost::any weight_map);
void export_components(); void export_components();
......
...@@ -112,7 +112,7 @@ __all__ = ["Graph", "GraphView", "Vertex", "Edge", "Vector_bool", ...@@ -112,7 +112,7 @@ __all__ = ["Graph", "GraphView", "Vertex", "Edge", "Vector_bool",
"Vector_int16_t", "Vector_int32_t", "Vector_int64_t", "Vector_double", "Vector_int16_t", "Vector_int32_t", "Vector_int64_t", "Vector_double",
"Vector_long_double", "Vector_string", "value_types", "load_graph", "Vector_long_double", "Vector_string", "value_types", "load_graph",
"PropertyMap", "group_vector_property", "ungroup_vector_property", "PropertyMap", "group_vector_property", "ungroup_vector_property",
"infect_vertex_property", "edge_difference", "show_config", "infect_vertex_property", "edge_difference", "seed_rng", "show_config",
"PropertyArray", "__author__", "__copyright__", "__URL__", "PropertyArray", "__author__", "__copyright__", "__URL__",
"__version__"] "__version__"]
...@@ -1964,3 +1964,15 @@ class GraphView(Graph): ...@@ -1964,3 +1964,15 @@ class GraphView(Graph):
return self.__base return self.__base
base = property(__get_base, doc="Base graph.") base = property(__get_base, doc="Base graph.")
_rng = libcore.get_rng(numpy.random.randint(0, sys.maxsize))
def seed_rng(seed):
"Seed the random number generator used by graph-tool's algorithms."
import graph_tool
graph_tool._rng = libcore.get_rng(int(seed))
def _get_rng():
global _rng
return _rng
...@@ -46,7 +46,7 @@ from __future__ import division, absolute_import, print_function ...@@ -46,7 +46,7 @@ from __future__ import division, absolute_import, print_function
from .. dl_import import dl_import from .. dl_import import dl_import
dl_import("from . import libgraph_tool_clustering as _gt") dl_import("from . import libgraph_tool_clustering as _gt")
from .. import _degree, _prop, Graph, GraphView from .. import _degree, _prop, Graph, GraphView, _get_rng
from .. topology import isomorphism from .. topology import isomorphism
from .. generation import random_rewire from .. generation import random_rewire
from .. stats import vertex_hist from .. stats import vertex_hist
...@@ -337,8 +337,6 @@ def motifs(g, k, p=1.0, motif_list=None): ...@@ -337,8 +337,6 @@ def motifs(g, k, p=1.0, motif_list=None):
:doi:`10.1109/TCBB.2006.51` :doi:`10.1109/TCBB.2006.51`
""" """
seed = random.randint(0, sys.maxsize)
sub_list = [] sub_list = []
directed_motifs = g.is_directed() directed_motifs = g.is_directed()
...@@ -364,7 +362,7 @@ def motifs(g, k, p=1.0, motif_list=None): ...@@ -364,7 +362,7 @@ def motifs(g, k, p=1.0, motif_list=None):
was_directed = g.is_directed() was_directed = g.is_directed()
_gt.get_motifs(g._Graph__graph, k, sub_list, hist, pd, _gt.get_motifs(g._Graph__graph, k, sub_list, hist, pd,
True, len(sub_list) == 0, True, len(sub_list) == 0,
seed) _get_rng())
# assemble graphs # assemble graphs
temp = [] temp = []
......
...@@ -66,7 +66,7 @@ Contents ...@@ -66,7 +66,7 @@ Contents
from __future__ import division, absolute_import, print_function from __future__ import division, absolute_import, print_function
from .. import GraphView, _check_prop_vector, group_vector_property, \ from .. import GraphView, _check_prop_vector, group_vector_property, \
ungroup_vector_property, infect_vertex_property, _prop ungroup_vector_property, infect_vertex_property, _prop, _get_rng
from .. topology import max_cardinality_matching, max_independent_vertex_set, \ from .. topology import max_cardinality_matching, max_independent_vertex_set, \
label_components, pseudo_diameter label_components, pseudo_diameter
from .. community import condensation_graph from .. community import condensation_graph
...@@ -338,7 +338,6 @@ def _coarse_graph(g, vweight, eweight, mivs=False, groups=None): ...@@ -338,7 +338,6 @@ def _coarse_graph(g, vweight, eweight, mivs=False, groups=None):
def _propagate_pos(g, cg, c, cc, cpos, delta, mivs): def _propagate_pos(g, cg, c, cc, cpos, delta, mivs):
seed = numpy.random.randint(sys.maxsize)
pos = g.new_vertex_property(cpos.value_type()) pos = g.new_vertex_property(cpos.value_type())
if mivs is not None: if mivs is not None:
...@@ -350,7 +349,7 @@ def _propagate_pos(g, cg, c, cc, cpos, delta, mivs): ...@@ -350,7 +349,7 @@ def _propagate_pos(g, cg, c, cc, cpos, delta, mivs):
_prop("v", g, pos), _prop("v", g, pos),
_prop("v", cg, cpos), _prop("v", cg, cpos),
delta if mivs is None else 0, delta if mivs is None else 0,
seed) _get_rng())
if mivs is not None: if mivs is not None:
g = g.base g = g.base
u = GraphView(g, directed=False) u = GraphView(g, directed=False)
...@@ -358,7 +357,7 @@ def _propagate_pos(g, cg, c, cc, cpos, delta, mivs): ...@@ -358,7 +357,7 @@ def _propagate_pos(g, cg, c, cc, cpos, delta, mivs):
libgraph_tool_layout.propagate_pos_mivs(u._Graph__graph, libgraph_tool_layout.propagate_pos_mivs(u._Graph__graph,
_prop("v", u, mivs), _prop("v", u, mivs),
_prop("v", u, pos), _prop("v", u, pos),
delta, seed) delta, _get_rng())
except ValueError: except ValueError:
graph_draw(u, mivs, vertex_fillcolor=mivs) graph_draw(u, mivs, vertex_fillcolor=mivs)
return pos return pos
......
...@@ -47,7 +47,7 @@ from __future__ import division, absolute_import, print_function ...@@ -47,7 +47,7 @@ from __future__ import division, absolute_import, print_function
from .. dl_import import dl_import from .. dl_import import dl_import
dl_import("from . import libgraph_tool_generation") dl_import("from . import libgraph_tool_generation")
from .. import Graph, GraphView, _check_prop_scalar, _prop, _limit_args, _gt_type from .. import Graph, GraphView, _check_prop_scalar, _prop, _limit_args, _gt_type, _get_rng
from .. stats import label_parallel_edges, label_self_loops from .. stats import label_parallel_edges, label_self_loops
import inspect import inspect
import types import types
...@@ -338,7 +338,6 @@ def random_graph(N, deg_sampler, deg_corr=None, cache_probs=True, directed=True, ...@@ -338,7 +338,6 @@ def random_graph(N, deg_sampler, deg_corr=None, cache_probs=True, directed=True,
no. 1: 016107 (2011) :doi:`10.1103/PhysRevE.83.016107` :arxiv:`1008.3926` no. 1: 016107 (2011) :doi:`10.1103/PhysRevE.83.016107` :arxiv:`1008.3926`
""" """
seed = numpy.random.randint(0, sys.maxsize)
g = Graph() g = Graph()
if deg_corr == None: if deg_corr == None:
uncorrelated = True uncorrelated = True
...@@ -370,7 +369,7 @@ def random_graph(N, deg_sampler, deg_corr=None, cache_probs=True, directed=True, ...@@ -370,7 +369,7 @@ def random_graph(N, deg_sampler, deg_corr=None, cache_probs=True, directed=True,
libgraph_tool_generation.gen_graph(g._Graph__graph, N, sampler, libgraph_tool_generation.gen_graph(g._Graph__graph, N, sampler,
uncorrelated, not parallel_edges, uncorrelated, not parallel_edges,
not self_loops, not directed, not self_loops, not directed,
seed, verbose, True) _get_rng(), verbose, True)
g.set_directed(directed) g.set_directed(directed)
if degree_block: if degree_block:
...@@ -670,8 +669,6 @@ def random_rewire(g, strat="uncorrelated", n_iter=1, edge_sweep=True, ...@@ -670,8 +669,6 @@ def random_rewire(g, strat="uncorrelated", n_iter=1, edge_sweep=True,
no. 1: 016107 (2011) :doi:`10.1103/PhysRevE.83.016107` :arxiv:`1008.3926` no. 1: 016107 (2011) :doi:`10.1103/PhysRevE.83.016107` :arxiv:`1008.3926`
""" """
seed = numpy.random.randint(0, sys.maxsize)
if not parallel_edges: if not parallel_edges:
p = label_parallel_edges(g) p = label_parallel_edges(g)
if p.a.max() != 0: if p.a.max() != 0:
...@@ -699,7 +696,7 @@ def random_rewire(g, strat="uncorrelated", n_iter=1, edge_sweep=True, ...@@ -699,7 +696,7 @@ def random_rewire(g, strat="uncorrelated", n_iter=1, edge_sweep=True,
self_loops, parallel_edges, self_loops, parallel_edges,
corr, _prop("v", g, blockmodel), corr, _prop("v", g, blockmodel),
cache_probs, cache_probs,
seed, verbose) _get_rng(), verbose)
if ret_fail: if ret_fail:
return pcount return pcount
...@@ -1223,7 +1220,5 @@ def price_network(N, m=1, c=None, gamma=1, directed=True, seed_graph=None): ...@@ -1223,7 +1220,5 @@ def price_network(N, m=1, c=None, gamma=1, directed=True, seed_graph=None):
N -= g.num_vertices() N -= g.num_vertices()
else: else:
g = seed_graph g = seed_graph
seed = numpy.random.randint(0, sys.maxsize) libgraph_tool_generation.price(g._Graph__graph, N, gamma, c, m, _get_rng())
libgraph_tool_generation.price(g._Graph__graph, N, gamma, c, m, seed)
return g return g
...@@ -49,7 +49,7 @@ from __future__ import division, absolute_import, print_function ...@@ -49,7 +49,7 @@ from __future__ import division, absolute_import, print_function
from .. dl_import import dl_import from .. dl_import import dl_import
dl_import("from . import libgraph_tool_stats") dl_import("from . import libgraph_tool_stats")
from .. import _degree, _prop from .. import _degree, _prop, _get_rng
from numpy import * from numpy import *
import numpy import numpy
import sys import sys
...@@ -391,14 +391,12 @@ def distance_histogram(g, weight=None, bins=[0, 1], samples=None, ...@@ -391,14 +391,12 @@ def distance_histogram(g, weight=None, bins=[0, 1], samples=None,
""" """
if samples != None: if samples != None:
seed = numpy.random.randint(0, sys.maxsize)
ret = libgraph_tool_stats.\ ret = libgraph_tool_stats.\
sampled_distance_histogram(g._Graph__graph, sampled_distance_histogram(g._Graph__graph,
_prop("e", g, weight), _prop("e", g, weight),
[float(x) for x in bins], [float(x) for x in bins],
samples, seed) samples, _get_rng())
else: else:
ret = libgraph_tool_stats.\ ret = libgraph_tool_stats.\
distance_histogram(g._Graph__graph, _prop("e", g, weight), bins) distance_histogram(g._Graph__graph, _prop("e", g, weight), bins)
return [array(ret[0], dtype="float64") if float_count else ret[0], ret[1]] return [array(ret[0], dtype="float64") if float_count else ret[0], ret[1]]
...@@ -65,7 +65,7 @@ from .. dl_import import dl_import ...@@ -65,7 +65,7 @@ from .. dl_import import dl_import
dl_import("from . import libgraph_tool_topology") dl_import("from . import libgraph_tool_topology")
from .. import _prop, Vector_int32_t, _check_prop_writable, \ from .. import _prop, Vector_int32_t, _check_prop_writable, \
_check_prop_scalar, _check_prop_vector, Graph, PropertyMap, GraphView _check_prop_scalar, _check_prop_vector, Graph, PropertyMap, GraphView, _get_rng
import random, sys, numpy import random, sys, numpy
__all__ = ["isomorphism", "subgraph_isomorphism", "mark_subgraph", __all__ = ["isomorphism", "subgraph_isomorphism", "mark_subgraph",
"max_cardinality_matching", "max_independent_vertex_set", "max_cardinality_matching", "max_independent_vertex_set",
...@@ -1542,11 +1542,10 @@ def max_independent_vertex_set(g, high_deg=False, mivs=None): ...@@ -1542,11 +1542,10 @@ def max_independent_vertex_set(g, high_deg=False, mivs=None):
_check_prop_scalar(mivs, "mivs") _check_prop_scalar(mivs, "mivs")
_check_prop_writable(mivs, "mivs") _check_prop_writable(mivs, "mivs")
seed = numpy.random.randint(0, sys.maxsize)
u = GraphView(g, directed=False) u = GraphView(g, directed=False)
libgraph_tool_topology.\ libgraph_tool_topology.\
maximal_vertex_set(u._Graph__graph, _prop("v", u, mivs), high_deg, maximal_vertex_set(u._Graph__graph, _prop("v", u, mivs), high_deg,
seed) _get_rng())
mivs = g.own_property(mivs) mivs = g.own_property(mivs)
return mivs return mivs
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment