Commit f44e89c4 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: Implement parallel sweeps for TemperingState with NestedBlockState

parent 5369eb34
......@@ -27,6 +27,7 @@ from .. import Vector_size_t, Vector_double
import numpy
from . util import *
from . nested_blockmodel import NestedBlockState
def mcmc_equilibrate(state, wait=1000, nbreaks=2, max_niter=numpy.inf,
force_niter=None, epsilon=0, gibbs=False, multiflip=False,
......@@ -673,6 +674,14 @@ class TemperingState(object):
"""Perform a full sweep of the parallel states, where state moves are
attempted by calling `sweep_algo(state, beta=beta, **kwargs)`."""
algo_states = []
if isinstance(self.states[0], NestedBlockState):
ls = list(kwargs.pop("ls", range(len(self.states[0].levels))))
if kwargs.pop("ls_shuffle", True):
numpy.random.shuffle(ls)
kwargs["ls"] = ls
kwargs["ls_shuffle"] = False
for state, beta in zip(self.states, self.betas):
entropy_args = dict(kwargs.get("entropy_args", {}))
algo_state = sweep_algo[0](state,
......
......@@ -673,7 +673,7 @@ class NestedBlockState(object):
state._couple_state(None, None)
return state
def _h_sweep(self, algo, **kwargs):
def _h_sweep_gen(self, **kwargs):
if not self.sampling:
raise ValueError("NestedBlockState must be constructed with 'sampling=True'")
......@@ -687,14 +687,11 @@ class NestedBlockState(object):
recs=False)
self.levels[l]._couple_state(self.levels[l + 1], eargs)
dS = 0
nattempts = 0
nmoves = 0
c = kwargs.get("c", None)
lrange = list(kwargs.pop("ls", range(len(self.levels))))
numpy.random.shuffle(lrange)
if kwargs.pop("ls_shuffle", True):
numpy.random.shuffle(lrange)
for l in lrange:
if check_verbose(verbose):
print(verbose_pad(verbose) + "level:", l)
......@@ -745,6 +742,17 @@ class NestedBlockState(object):
if len(rs) > 0:
reverse_map(rs, self.levels[l].empty_pos)
yield l, self.levels[l], args
def _h_sweep(self, algo, **kwargs):
entropy_args = kwargs.get("entropy_args", {})
dS = 0
nattempts = 0
nmoves = 0
for l, lstate, args in self._h_sweep_gen(**kwargs):
ret = algo(self.levels[l], **args)
if l > 0 and "beta_dl" in entropy_args:
......@@ -756,6 +764,30 @@ class NestedBlockState(object):
return dS, nattempts, nmoves
def _h_sweep_states(self, algo, **kwargs):
entropy_args = kwargs.get("entropy_args", {})
for l, lstate, args in self._h_sweep_gen(**kwargs):
if l > 0 and "beta_dl" in entropy_args:
yield l, lstate, algo(self.levels[l], dispatch=False, **args), entropy_args["beta_dl"]
else:
yield l, lstate, algo(self.levels[l], dispatch=False, **args), 1
def _h_sweep_parallel_dispatch(states, sweeps, algo):
ret = None
for lsweep in zip(*sweeps):
ls = [x[0] for x in lsweep]
lstates = [x[1] for x in lsweep]
lsweep_states = [x[2] for x in lsweep]
beta_dl = [x[3] for x in lsweep]
lret = algo(type(lstates[0]), lstates, lsweep_states)
if ret is None:
ret = lret
else:
ret = [(ret[i][0] + lret[i][0] * beta_dl[i],
ret[i][1] + lret[i][1],
ret[i][2] + lret[i][2]) for i in range(len(lret))]
return ret
def mcmc_sweep(self, **kwargs):
r"""Perform ``niter`` sweeps of a Metropolis-Hastings acceptance-rejection
MCMC to sample hierarchical network partitions.
......@@ -773,20 +805,28 @@ class NestedBlockState(object):
if not isinstance(c, collections.Iterable):
c = [c] + [c * 2 ** l for l in range(1, len(self.levels))]
if _bm_test():
kwargs = dict(kwargs, test=False)
entropy_args = kwargs.get("entropy_args", {})
Si = self.entropy(**entropy_args)
if kwargs.pop("dispatch", True):
if _bm_test():
kwargs = dict(kwargs, test=False)
entropy_args = kwargs.get("entropy_args", {})
Si = self.entropy(**entropy_args)
dS, nattempts, nmoves = self._h_sweep(lambda s, **a: s.mcmc_sweep(**a),
c=c, **kwargs)
dS, nattempts, nmoves = self._h_sweep(lambda s, **a: s.mcmc_sweep(**a),
c=c, **kwargs)
if _bm_test():
Sf = self.entropy(**entropy_args)
assert math.isclose(dS, (Sf - Si), abs_tol=1e-8), \
"inconsistent entropy delta %g (%g): %s" % (dS, Sf - Si,
str(entropy_args))
return dS, nattempts, nmoves
if _bm_test():
Sf = self.entropy(**entropy_args)
assert math.isclose(dS, (Sf - Si), abs_tol=1e-8), \
"inconsistent entropy delta %g (%g): %s" % (dS, Sf - Si,
str(entropy_args))
return dS, nattempts, nmoves
else:
return self._h_sweep_states(lambda s, **a: s.mcmc_sweep(**a),
c=c, **kwargs)
def _mcmc_sweep_parallel_dispatch(states, sweeps):
algo = lambda s, lstates, lsweep_states: s._mcmc_sweep_parallel_dispatch(lstates, lsweep_states)
return NestedBlockState._h_sweep_parallel_dispatch(states, sweeps, algo)
def multiflip_mcmc_sweep(self, **kwargs):
r"""Perform ``niter`` sweeps of a Metropolis-Hastings acceptance-rejection MCMC
......@@ -806,19 +846,27 @@ class NestedBlockState(object):
if not isinstance(c, collections.Iterable):
c = [c] + [c * 2 ** l for l in range(1, len(self.levels))]
if _bm_test():
kwargs = dict(kwargs, test=False)
entropy_args = kwargs.get("entropy_args", {})
Si = self.entropy(**entropy_args)
if kwargs.pop("dispatch", True):
if _bm_test():
kwargs = dict(kwargs, test=False)
entropy_args = kwargs.get("entropy_args", {})
Si = self.entropy(**entropy_args)
dS, nattempts, nmoves = self._h_sweep(lambda s, **a: s.multiflip_mcmc_sweep(**a),
c=c, **kwargs)
if _bm_test():
Sf = self.entropy(**entropy_args)
assert math.isclose(dS, (Sf - Si), abs_tol=1e-8), \
"inconsistent entropy delta %g (%g): %s" % (dS, Sf - Si,
str(entropy_args))
return dS, nattempts, nmoves
dS, nattempts, nmoves = self._h_sweep(lambda s, **a: s.multiflip_mcmc_sweep(**a),
c=c, **kwargs)
if _bm_test():
Sf = self.entropy(**entropy_args)
assert math.isclose(dS, (Sf - Si), abs_tol=1e-8), \
"inconsistent entropy delta %g (%g): %s" % (dS, Sf - Si,
str(entropy_args))
return dS, nattempts, nmoves
else:
return self._h_sweep_states(lambda s, **a: s.multiflip_mcmc_sweep(**a),
c=c, **kwargs)
def _multiflip_mcmc_sweep_parallel_dispatch(states, sweeps):
algo = lambda s, lstates, lsweep_states: s._multiflip_mcmc_sweep_parallel_dispatch(lstates, lsweep_states)
return NestedBlockState._h_sweep_parallel_dispatch(states, sweeps, algo)
def gibbs_sweep(self, **kwargs):
r"""Perform ``niter`` sweeps of a rejection-free Gibbs MCMC to sample network
......@@ -827,19 +875,27 @@ class NestedBlockState(object):
The arguments accepted are the same as in
:method:`graph_tool.inference.BlockState.gibbs_sweep`.
"""
if _bm_test():
kwargs = dict(kwargs, test=False)
entropy_args = kwargs.get("entropy_args", {})
Si = self.entropy(**entropy_args)
if kwargs.pop("dispatch", True):
if _bm_test():
kwargs = dict(kwargs, test=False)
entropy_args = kwargs.get("entropy_args", {})
Si = self.entropy(**entropy_args)
dS, nattempts, nmoves = self._h_sweep(lambda s, **a: s.gibbs_sweep(**a))
dS, nattempts, nmoves = self._h_sweep(lambda s, **a: s.gibbs_sweep(**a))
if _bm_test():
Sf = self.entropy(**entropy_args)
assert math.isclose(dS, (Sf - Si), abs_tol=1e-8), \
"inconsistent entropy delta %g (%g): %s" % (dS, Sf - Si,
str(entropy_args))
return dS, nattempts, nmoves
if _bm_test():
Sf = self.entropy(**entropy_args)
assert math.isclose(dS, (Sf - Si), abs_tol=1e-8), \
"inconsistent entropy delta %g (%g): %s" % (dS, Sf - Si,
str(entropy_args))
return dS, nattempts, nmoves
else:
return self._h_sweep_states(lambda s, **a: s.gibbs_sweep(**a),
**kwargs)
def _gibbs_sweep_parallel_dispatch(states, sweeps):
algo = lambda s, lstates, lsweep_states: s._gibbs_sweep_parallel_dispatch(lstates, lsweep_states)
return NestedBlockState._h_sweep_parallel_dispatch(states, sweeps, algo)
def multicanonical_sweep(self, **kwargs):
r"""Perform ``niter`` sweeps of a non-Markovian multicanonical sampling using the
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment