Commit a0399483 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

Implement DrawBlockState base state

parent 38e6fab9
......@@ -72,7 +72,7 @@ from .. import Graph, GraphView, _check_prop_vector, _check_prop_scalar, \
from .. topology import max_cardinality_matching, max_independent_vertex_set, \
label_components, shortest_distance, make_maximal_planar, is_planar
from .. generation import predecessor_tree, condensation_graph
from .. inference import nested_contiguous_map
from .. inference.util import nested_contiguous_map
import numpy.random
from numpy import sqrt
......
......@@ -18,7 +18,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from .. import Vector_size_t
from .. import Vector_size_t, group_vector_property
from . util import *
......@@ -28,6 +28,8 @@ import gc
import numpy
import graph_tool.draw
__test__ = False
def set_test(test):
......@@ -690,3 +692,22 @@ class ExhaustiveSweepState(ABC):
return (Ss, density[2]), b_min
else:
return b_min
class DrawBlockState(ABC):
r"""Base state that implements group-based drawing."""
def draw(self, **kwargs):
r"""Convenience wrapper to :func:`~graph_tool.draw.graph_draw` that
draws the state of the graph as colors on the vertices and edges."""
gradient = self.g.new_ep("double")
gradient = group_vector_property([gradient])
return graph_tool.draw.graph_draw(self.g,
vertex_fill_color=kwargs.get("vertex_fill_color",
self.b),
vertex_color=kwargs.get("vertex_color", self.b),
edge_gradient=kwargs.get("edge_gradient",
gradient),
**dmask(kwargs, ["vertex_fill_color",
"vertex_color",
"edge_gradient"]))
......@@ -117,7 +117,8 @@ def init_q_cache(max_n=None):
libinference.init_q_cache(min(_q_cache_max_n, max_n))
class BlockState(MCMCState, MultiflipMCMCState, MultilevelMCMCState,
GibbsMCMCState, MulticanonicalMCMCState, ExhaustiveSweepState):
GibbsMCMCState, MulticanonicalMCMCState, ExhaustiveSweepState,
DrawBlockState):
r"""The stochastic block model state of a given graph.
Parameters
......@@ -1530,22 +1531,6 @@ class BlockState(MCMCState, MultiflipMCMCState, MultilevelMCMCState,
h, update, unlabel)
return h
def draw(self, **kwargs):
r"""Convenience wrapper to :func:`~graph_tool.draw.graph_draw` that
draws the state of the graph as colors on the vertices and edges."""
gradient = self.g.new_ep("double")
gradient = group_vector_property([gradient])
from graph_tool.draw import graph_draw
return graph_draw(self.g,
vertex_fill_color=kwargs.get("vertex_fill_color",
self.b),
vertex_color=kwargs.get("vertex_color", self.b),
edge_gradient=kwargs.get("edge_gradient",
gradient),
**dmask(kwargs, ["vertex_fill_color",
"vertex_color",
"edge_gradient"]))
def sample_graph(self, canonical=False, multigraph=True, self_loops=True,
sample_params=False, max_ent=False, n_iter=1000):
r"""Sample a new graph from the fitted model.
......
......@@ -90,7 +90,7 @@ def modularity(g, b, gamma=1., weight=None):
class ModularityState(MCMCState, MultiflipMCMCState, MultilevelMCMCState,
GibbsMCMCState):
GibbsMCMCState, DrawBlockState):
r"""Obtain the partition of a network according to Newman's modularity.
.. warning::
......@@ -229,22 +229,6 @@ class ModularityState(MCMCState, MultiflipMCMCState, MultilevelMCMCState,
Q = self.entropy(gamma=gamma)
return -Q / (2 * self.g.num_edges())
def draw(self, **kwargs):
r"""Convenience wrapper to :func:`~graph_tool.draw.graph_draw` that
draws the state of the graph as colors on the vertices and edges."""
gradient = self.g.new_ep("double")
gradient = group_vector_property([gradient])
from graph_tool.draw import graph_draw
return graph_draw(self.g,
vertex_fill_color=kwargs.get("vertex_fill_color",
self.b),
vertex_color=kwargs.get("vertex_color", self.b),
edge_gradient=kwargs.get("edge_gradient",
gradient),
**dmask(kwargs, ["vertex_fill_color",
"vertex_color",
"edge_gradient"]))
def _mcmc_sweep_dispatch(self, mcmc_state):
return libinference.modularity_mcmc_sweep(mcmc_state, self._state,
_get_rng())
......
......@@ -31,7 +31,7 @@ import numpy as np
import math
class PPBlockState(MCMCState, MultiflipMCMCState, MultilevelMCMCState,
GibbsMCMCState):
GibbsMCMCState, DrawBlockState):
r"""Obtain the partition of a network according to the Bayesian planted partition
model.
......@@ -258,22 +258,6 @@ class PPBlockState(MCMCState, MultiflipMCMCState, MultilevelMCMCState,
str(list(kwargs.keys())))
return ea
def draw(self, **kwargs):
r"""Convenience wrapper to :func:`~graph_tool.draw.graph_draw` that
draws the state of the graph as colors on the vertices and edges."""
gradient = self.g.new_ep("double")
gradient = group_vector_property([gradient])
from graph_tool.draw import graph_draw
return graph_draw(self.g,
vertex_fill_color=kwargs.get("vertex_fill_color",
self.b),
vertex_color=kwargs.get("vertex_color", self.b),
edge_gradient=kwargs.get("edge_gradient",
gradient),
**dmask(kwargs, ["vertex_fill_color",
"vertex_color",
"edge_gradient"]))
def _mcmc_sweep_dispatch(self, mcmc_state):
return libinference.pp_mcmc_sweep(mcmc_state, self._state,
_get_rng())
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment