Commit 67372fc8 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: improve multicanonical sampling

parent 3f0baa5b
......@@ -390,8 +390,8 @@ class MulticanonicalState(object):
Parameters
----------
g : :class:`~graph_tool.Graph`
Graph to be modelled.
state : :class:`~graph_tool.inference.Blockstate` or :class:`~graph_tool.inference.OverlapBlockstate` or :class:`~graph_tool.inference.NestedBlockstate`
Block state to be used.
S_min : ``float``
Minimum energy.
S_max : ``float``
......@@ -400,9 +400,10 @@ class MulticanonicalState(object):
Number of bins.
"""
def __init__(self, g, S_min, S_max, nbins=1000):
self._g = g
self._N = g.num_vertices()
def __init__(self, state, S_min, S_max, nbins=1000):
self._state = state
self._g = state.g
self._N = self._g.num_vertices()
self._S_min = S_min
self._S_max = S_max
self._density = Vector_double()
......@@ -413,18 +414,21 @@ class MulticanonicalState(object):
self._f = None
def __getstate__(self):
state = [self._g, self._S_min, self._S_max,
state = [self._state, self._S_min, self._S_max,
numpy.array(self._density.a), numpy.array(self._hist.a),
numpy.array(self._perm_hist), self._f]
return state
def __setstate__(self, state):
g, S_min, S_max, density, hist, phist, self._f = state
self.__init__(g, S_min, S_max, len(hist))
bstate, S_min, S_max, density, hist, phist, self._f = state
self.__init__(bstate, S_min, S_max, len(hist))
self._density.a[:] = density
self._hist.a[:] = hist
self._perm_hist[:] = phist
def sweep(self, **kwargs):
self._state.multicanonical_sweep(self, **kwargs)
def get_energies(self):
"Get energy bounds."
return self._S_min, self._S_max
......@@ -507,7 +511,7 @@ class MulticanonicalState(object):
self._perm_hist += self._hist.a
self._hist.a = 0
def multicanonical_equilibrate(state, m_state, f_range=(1., 1e-6), r=2,
def multicanonical_equilibrate(m_state, f_range=(1., 1e-6), r=2,
flatness=.95, allow_gaps=True, callback=None,
multicanonical_args={}, verbose=False):
r"""Equilibrate a multicanonical Monte Carlo sampling using the Wang-Landau
......@@ -515,8 +519,6 @@ def multicanonical_equilibrate(state, m_state, f_range=(1., 1e-6), r=2,
Parameters
----------
state : Any state class (e.g. :class:`~graph_tool.inference.blockmodel.BlockState`)
Initial state. This state will be modified during the algorithm.
m_state : :class:`~graph_tool.inference.mcmc.MulticanonicalState`
Initial multicanonical state, where the state density will be stored.
f_range : ``tuple`` of two floats (optional, default: ``(1., 1e-6)``)
......@@ -565,7 +567,7 @@ def multicanonical_equilibrate(state, m_state, f_range=(1., 1e-6), r=2,
if m_state._f is None:
m_state._f = f_range[0]
while m_state._f >= f_range[1]:
state.multicanonical_sweep(m_state, **multicanonical_args)
m_state.sweep(**multicanonical_args)
hf = m_state.get_flatness(allow_gaps=allow_gaps)
if callback is not None:
......@@ -575,8 +577,8 @@ def multicanonical_equilibrate(state, m_state, f_range=(1., 1e-6), r=2,
print(verbose_pad(verbose) +
"count: %d f: %#8.8g flatness: %#8.8g nonempty bins: %d S: %#8.8g B: %d" % \
(count, m_state._f, hf, (m_state._hist.a > 0).sum(),
state.entropy(**multicanonical_args.get("entropy_args", {})),
state.get_nonempty_B()))
m_state._state.entropy(**multicanonical_args.get("entropy_args", {})),
m_state._state.get_nonempty_B()))
if hf > flatness:
m_state._f /= r
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment