nested_blockmodel.py 52.5 KB
Newer Older
1
2
3
4
5
#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# graph_tool -- a general graph manipulation python module
#
Tiago Peixoto's avatar
Tiago Peixoto committed
6
# Copyright (C) 2006-2019 Tiago de Paula Peixoto <tiago@skewed.de>
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import division, absolute_import, print_function
import sys
if sys.version_info < (3,):
    range = xrange

from .. import _degree, _prop, Graph, GraphView, conv_pickle_state
from . blockmodel import *
from . blockmodel import _bm_test
from . overlap_blockmodel import *
from . layered_blockmodel import *

from numpy import *
import numpy
import copy

class NestedBlockState(object):
    r"""The nested stochastic block model state of a given graph.

    Parameters
    ----------
    g : :class:`~graph_tool.Graph`
Tiago Peixoto's avatar
Tiago Peixoto committed
42
        Graph to be modeled.
43
    bs : ``list`` of :class:`~graph_tool.VertexPropertyMap` or :class:`numpy.ndarray`
44
        Hierarchical node partition.
45
    base_type : ``type`` (optional, default: :class:`~graph_tool.inference.blockmodel.BlockState`)
Tiago Peixoto's avatar
Tiago Peixoto committed
46
        State type for lowermost level
47
48
49
        (e.g. :class:`~graph_tool.inference.blockmodel.BlockState`,
        :class:`~graph_tool.inference.overlap_blockmodel.OverlapBlockState` or
        :class:`~graph_tool.inference.layered_blockmodel.LayeredBlockState`)
50
51
52
53
54
55
    hstate_args : ``dict`` (optional, default: `{}`)
        Keyword arguments to be passed to the constructor of the higher-level
        states.
    hentropy_args : ``dict`` (optional, default: `{}`)
        Keyword arguments to be passed to the ``entropy()`` method of the
        higher-level states.
56
    sampling : ``bool`` (optional, default: ``True``)
57
58
        If ``True``, the state will be properly prepared for MCMC sampling (as
        opposed to minimization).
59
    state_args : ``dict`` (optional, default: ``{}``)
60
        Keyword arguments to be passed to base type constructor.
61
62
63
    **kwargs :  keyword arguments
        Keyword arguments to be passed to base type constructor. The
        ``state_args`` parameter overrides this.
64
    """
65

66
    def __init__(self, g, bs, base_type=BlockState, state_args={},
                 hstate_args={}, hentropy_args={}, sampling=True, **kwargs):
        # The dict defaults above are only ever read: every use below builds a
        # new dict from them.
        self.g = g
        self.base_type = base_type
        self.sampling = sampling

        # A single container is handed to every level through the "Lrecdx"
        # constructor argument.
        if base_type is LayeredBlockState:
            self.Lrecdx = []
        else:
            self.Lrecdx = libcore.Vector_double()

        # Arguments for the lowermost state; ``state_args`` takes precedence
        # over the loose keyword arguments.
        base_args = dict(kwargs, **state_args)
        base_args["Lrecdx"] = self.Lrecdx
        if "rec_params" not in base_args:
            recs = base_args.get("recs", None)
            if recs is not None:
                base_args["rec_params"] = ["microcanonical"] * len(recs)
        self.state_args = base_args

        # Arguments for the higher-level states.
        upper_args = dict(dict(deg_corr=False, vweight="nonempty"),
                          **hstate_args)
        upper_args["Lrecdx"] = self.Lrecdx
        if sampling:
            upper_args = dict(upper_args, copy_bg=False)
        self.hstate_args = upper_args

        # The keyword values listed here override whatever was supplied in
        # ``hentropy_args``.
        self.hentropy_args = dict(hentropy_args,
                                  adjacency=True,
                                  dense=True,
                                  multigraph=True,
                                  dl=True,
                                  partition_dl=True,
                                  degree_dl=True,
                                  degree_dl_kind="distributed",
                                  edges_dl=False,
                                  exact=True,
                                  recs=True,
                                  recs_dl=False,
                                  beta_dl=1.)

        # Build the hierarchy bottom-up; each level is derived from the one
        # directly below it.
        self.levels = [base_type(g, b=bs[0], **self.state_args)]
        upper_bs = bs[1:]
        last = len(upper_bs) - 1
        for i, b in enumerate(upper_bs):
            level_args = self.hstate_args
            if i == last:
                # the topmost level is created without label constraints
                level_args = dict(level_args, clabel=None, pclabel=None)
            below = self.levels[-1]
            self.levels.append(below.get_block_state(b=b, **level_args))

        self._regen_Lrecdx()

        if self.sampling:
            self._couple_levels(self.hentropy_args, None)

        if _bm_test():
            self._consistency_check()

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
    def _regen_Lrecdx(self, lstate=None):
        """Recompute the ``Lrecdx`` values from the per-level ``recdx`` values.

        With ``lstate=None`` the computation uses ``self.levels`` and writes
        into ``self.Lrecdx``; nothing is returned.

        With ``lstate=(l, s)`` the computation uses a list in which level ``l``
        is replaced by ``s`` and level ``l + 1`` by a block state derived from
        ``s``; the result is accumulated in a copy of ``self.Lrecdx``, which is
        returned.
        """
        if lstate is None:
            levels = self.levels
            Lrecdx = self.Lrecdx
        else:
            # shallow copy: ``self.levels`` itself is not modified
            levels = [s for s in self.levels]
            l, s = lstate
            levels[l] = s
            # the level above is derived from the substituted state
            s = s.get_block_state(**dict(self.hstate_args,
                                         b=s.get_bclabel(),
                                         copy_bg=False))
            if l < len(levels) - 1:
                levels[l+1] = s
            else:
                levels.append(s)
            if self.base_type is LayeredBlockState:
                Lrecdx = [x.copy() for x in self.Lrecdx]
            else:
                Lrecdx = self.Lrecdx.copy()

        if self.base_type is not LayeredBlockState:
            Lrecdx.a = 0
            # entry 0: number of levels for which get_B_E_D() is positive
            Lrecdx[0] = len([s for s in levels if s._state.get_B_E_D() > 0])
            for s in levels:
                # remaining entries: sum over levels of recdx * B_E_D
                Lrecdx.a[1:] += s.recdx.a * s._state.get_B_E_D()
                # every level takes its epsilon values from the lowest level
                s.epsilon.a = levels[0].epsilon.a
            for s in levels:
                s.Lrecdx.a = Lrecdx.a
        else:
            # Layered case: Lrecdx[0] holds the totals and Lrecdx[j+1] the
            # values for layer j, with the same layout as above.
            Lrecdx[0].a = 0
            Lrecdx[0][0] = len([s for s in levels if s._state.get_B_E_D() > 0])
            for j in range(levels[0].C):
                Lrecdx[j+1].a = 0
                Lrecdx[j+1][0] = len([s for s in levels if s._state.get_layer(j).get_B_E_D() > 0])
            for s in levels:
                Lrecdx[0].a[1:] += s.recdx.a * s._state.get_B_E_D()
                s.epsilon.a = levels[0].epsilon.a
                for j in range(levels[0].C):
                    Lrecdx[j+1].a[1:] += s.layer_states[j].recdx.a * s._state.get_layer(j).get_B_E_D()
                    s.layer_states[j].epsilon.a = levels[0].epsilon.a

            # NOTE(review): this loop iterates ``self.levels`` whereas the
            # non-layered branch iterates the local ``levels`` -- confirm the
            # asymmetry is intended when ``lstate`` is given.
            for s in self.levels:
                for x, y in zip(s.Lrecdx, Lrecdx):
                    x.a = y.a

        if lstate is not None:
            return Lrecdx

164

165
166
167
    def _regen_levels(self):
        for l in range(1, len(self.levels)):
            state = self.levels[l]
168
169
            nstate = self.levels[l-1].get_block_state(b=state.b,
                                                      **self.hstate_args)
170
            self.levels[l] = nstate
171
        self._regen_Lrecdx()
172

173
174
175
    def __repr__(self):
        return "<NestedBlockState object, with base %s, and %d levels of sizes %s at 0x%x>" % \
            (repr(self.levels[0]), len(self.levels),
176
             str([(s.get_N(), s.get_nonempty_B()) for s in self.levels]), id(self))
177
178
179
180
181
182
183
184

    def __copy__(self):
        return self.copy()

    def __deepcopy__(self, memo):
        g = copy.deepcopy(self.g, memo)
        return self.copy(g=g)

185
186
    def copy(self, g=None, bs=None, state_args=None, hstate_args=None,
             hentropy_args=None, sampling=None, **kwargs):
        r"""Return a copy of this state.

        Any parameter that is given replaces the corresponding property of the
        current state; the parameters have the same meaning as in the
        constructor.
        """
        if bs is None:
            bs = self.get_bs()
        if g is None:
            g = self.g
        if state_args is None:
            state_args = self.state_args
        if hstate_args is None:
            hstate_args = self.hstate_args
        if hentropy_args is None:
            hentropy_args = self.hentropy_args
        if sampling is None:
            sampling = self.sampling
        return NestedBlockState(g, bs,
                                base_type=type(self.levels[0]),
                                state_args=state_args,
                                hstate_args=hstate_args,
                                hentropy_args=hentropy_args,
                                sampling=sampling,
                                **kwargs)
197
198

    def __getstate__(self):
199
200
        state = dict(g=self.g, bs=self.get_bs(), base_type=type(self.levels[0]),
                     hstate_args=self.hstate_args,
Tiago Peixoto's avatar
Tiago Peixoto committed
201
                     hentropy_args=self.hentropy_args, sampling=self.sampling,
202
                     state_args=self.state_args)
203
204
205
206
        return state

    def __setstate__(self, state):
        """Rebuild the state from a pickled dictionary by re-running the
        constructor."""
        conv_pickle_state(state)
        # backwards compatibility: older pickles stored the base-state
        # arguments under "kwargs"
        if "kwargs" in state:
            state["state_args"] = state.pop("kwargs")
        self.__init__(**state)
211

212
213
214
215
216
217
    def get_bs(self):
        """Get hierarchy levels as a list of :class:`numpy.ndarray` objects with the
        group memberships at each level.
        """
        return [s.b.fa for s in self.levels]

218
219
220
221
    def get_state(self):
        """Alias to :meth:`~NestedBlockState.get_bs`."""
        return self.get_bs()

222
223
224
225
226
    def set_state(self, bs):
        r"""Sets the internal nested partition of the state."""
        for i in range(len(bs)):
            self.levels[i].set_state(bs[i])

227
    def get_levels(self):
228
        """Get hierarchy levels as a list of :class:`~graph_tool.inference.blockmodel.BlockState`
229
230
231
        instances."""
        return self.levels

232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
    def project_partition(self, j, l):
        """Project partition of level ``j`` onto level ``l``, and return it."""
        b = self.levels[l].b.copy()
        for i in range(l + 1, j + 1):
            clabel = self.levels[i].b.copy()
            pmap(b, clabel)
        return b

    def propagate_clabel(self, l):
        """Project base clabel to level ``l``."""
        clabel = self.levels[0].clabel.copy()
        for j in range(l):
            bg = self.levels[j].bg
            bclabel = bg.new_vertex_property("int")
            reverse_map(self.levels[j].b, bclabel)
            pmap(bclabel, clabel)
            clabel = bclabel
        return clabel

    def get_clabel(self, l):
        """Get clabel for level ``l``."""
        clabel = self.propagate_clabel(l)
        if l < len(self.levels) - 1:
            b = self.project_partition(l + 1, l)
            clabel.fa += (clabel.fa.max() + 1) * b.fa
        return clabel

    def _consistency_check(self):
        for l in range(1, len(self.levels)):
            b = self.levels[l].b.fa.copy()
            state = self.levels[l-1]
263
264
265
266
            args = self.hstate_args
            if l == len(self.levels) - 1:
                args = dict(args, clabel=None, pclabel=None)
            bstate = state.get_block_state(b=b, **args)
267
268
269
270
            b2 = bstate.b.fa.copy()
            continuous_map(b)
            continuous_map(b2)
            assert ((b == b2).all() and
271
272
273
                    math.isclose(bstate.entropy(dl=False),
                                 self.levels[l].entropy(dl=False),
                                 abs_tol=1e-8)), \
274
275
276
                "inconsistent level %d (%s %g,  %s %g): %s" % \
                (l, str(bstate), bstate.entropy(), str(self.levels[l]),
                 self.levels[l].entropy(), str(self))
277
278
            assert (bstate.get_N() >= bstate.get_nonempty_B()), \
                (l, bstate.get_N(), bstate.get_nonempty_B(), str(self))
279
280
281
282
283
284
285
286
287
288
289

    def replace_level(self, l, b):
        """Replace the partition of level ``l`` by ``b``.

        If a level exists above ``l``, it is rebuilt on top of the new level
        from the projection of its previous partition.
        """
        has_upper = l < len(self.levels) - 1

        if has_upper:
            # projection of the upper partition, taken before level l changes
            clabel = self.project_partition(l + 1, l)

        self.levels[l] = self.levels[l].copy(b=b)

        if has_upper:
            new_level = self.levels[l]
            bclabel = new_level.bg.new_vertex_property("int")
            reverse_map(new_level.b, bclabel)
            pmap(bclabel, clabel)
            self.levels[l + 1] = new_level.get_block_state(b=bclabel,
                                                           **self.hstate_args)

        self._regen_Lrecdx()

        if _bm_test():
            self._consistency_check()

    def delete_level(self, l):
        """Delete level ``l``."""
        if l == 0:
            raise ValueError("cannot delete level l=0")
        b = self.project_partition(l, l - 1)
304
        self.replace_level(l - 1, b.fa)
305
306
        del self.levels[l]

307
308
        self._regen_Lrecdx()

309
310
311
312
313
314
315
        if _bm_test():
            self._consistency_check()

    def duplicate_level(self, l):
        """Duplicate level ``l``.

        A copy of level ``l`` whose partition is the vertex index of its graph
        is inserted at position ``l``.
        """
        level = self.levels[l]
        identity = level.g.vertex_index.copy("int").fa
        self.levels.insert(l, level.copy(b=identity))

        self._regen_Lrecdx()

        if _bm_test():
            self._consistency_check()

320
    def level_entropy(self, l, bstate=None, **kwargs):
321
322
323
324
325
        """Compute the entropy of level ``l``."""

        if bstate is None:
            bstate = self.levels[l]

326
327
328
329
330
331
332
        kwargs = kwargs.copy()
        hentropy_args = dict(self.hentropy_args,
                             **kwargs.pop("hentropy_args", {}))
        hentropy_args_top = dict(dict(hentropy_args, edges_dl=True,
                                      recs_dl=True),
                                 **kwargs.pop("hentropy_args_top", {}))

333
        if l > 0:
334
335
336
337
            if l == (len(self.levels) - 1):
                eargs = hentropy_args_top
            else:
                eargs = hentropy_args
338
        else:
339
            eargs = dict(kwargs, edges_dl=False)
340

341
        S = bstate.entropy(**eargs)
342
343
344
345

        if l > 0:
            S *= kwargs.get("beta_dl", 1.)

346
347
        return S

348
    def _Lrecdx_entropy(self, Lrecdx=None):
        """Return the entropy contribution computed from the ``Lrecdx`` values.

        Only edge covariates whose type is ``real_normal`` contribute. If
        ``Lrecdx`` is not given, ``self.Lrecdx`` is used and a term
        ``-log(B_E_D)`` is accumulated for every level with ``B_E_D > 0``;
        if ``Lrecdx`` is given, that term is left at zero.
        """
        if self.base_type is not LayeredBlockState:
            S_D = 0

            if Lrecdx is None:
                Lrecdx = self.Lrecdx
                for s in self.levels:
                    B_E_D = s._state.get_B_E_D()
                    if B_E_D > 0:
                        S_D -= log(B_E_D)

            S = 0
            for i in range(len(self.levels[0].rec)):
                if self.levels[0].rec_types[i] != libinference.rec_type.real_normal:
                    continue
                # sanity check, only evaluated when testing is enabled
                assert not _bm_test() or Lrecdx[i+1] >= 0, (i, Lrecdx[i+1])
                S += -libinference.positive_w_log_P(Lrecdx[0], Lrecdx[i+1],
                                                    numpy.nan, numpy.nan,
                                                    self.levels[0].epsilon[i])
                # S_D is added once per contributing covariate
                S += S_D
            return S
        else:
            # Layered case: one S_D term per layer; Lrecdx[j+1] holds the
            # values of layer j.
            S_D = [0 for j in range(self.levels[0].C)]
            if Lrecdx is None:
                Lrecdx = self.Lrecdx
                for s in self.levels:
                    for j in range(self.levels[0].C):
                        B_E_D = s._state.get_layer(j).get_B_E_D()
                        if B_E_D > 0:
                            S_D[j] -= log(B_E_D)

            S = 0
            for i in range(len(self.levels[0].rec)):
                if self.levels[0].rec_types[i] != libinference.rec_type.real_normal:
                    continue
                for j in range(self.levels[0].C):
                    assert not _bm_test() or Lrecdx[j+1][i+1] >= 0, (i, j, Lrecdx[j+1][i+1])
                    S += -libinference.positive_w_log_P(Lrecdx[j+1][0],
                                                        Lrecdx[j+1][i+1],
                                                        numpy.nan, numpy.nan,
                                                        self.levels[0].epsilon[i])
                    S += S_D[j]
            return S
391

392

393
    def entropy(self, **kwargs):
        """Compute the entropy of the whole hierarchy.

        The result is the sum of the entropies of all levels plus the
        ``Lrecdx`` contribution scaled by ``beta_dl`` (default ``1``).

        The keyword arguments are forwarded to the ``entropy()`` method of the
        underlying state objects
        (e.g. :class:`graph_tool.inference.blockmodel.BlockState.entropy`,
        :class:`graph_tool.inference.overlap_blockmodel.OverlapBlockState.entropy`, or
        :class:`graph_tool.inference.layered_blockmodel.LayeredBlockState.entropy`).
        """
        level_kwargs = dict(kwargs, test=False)
        S = 0
        for l in range(len(self.levels)):
            S += self.level_entropy(l, **level_kwargs)

        S += kwargs.get("beta_dl", 1.) * self._Lrecdx_entropy()

        # In test mode, verify that a copy of the state has the same entropy.
        if _bm_test() and kwargs.pop("test", True):
            Salt = self.copy().entropy(test=False, **kwargs)
            assert math.isclose(S, Salt, abs_tol=1e-8), \
                "inconsistent entropy after copying (%g, %g, %g): %s" % \
                (S, Salt, S-Salt, str(kwargs))

        return S

416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
    def move_vertex(self, v, s):
        r"""Move vertex ``v`` to block ``s``."""
        self.levels[0].move_vertex(v, s)
        self._regen_levels()

    def remove_vertex(self, v):
        r"""Remove vertex ``v`` from its current group.

        This optionally accepts a list of vertices to remove.

        .. warning::

           This will leave the state in an inconsistent state before the vertex
           is returned to some other group, or if the same vertex is removed
           twice.
        """
        self.levels[0].remove_vertex(v)
        self._regen_levels()

    def add_vertex(self, v, r):
        r"""Add vertex ``v`` to block ``r``.

        This optionally accepts a list of vertices and blocks to add.

        .. warning::

           This can leave the state in an inconsistent state if a vertex is
           added twice to the same group.
        """
        self.levels[0].add_vertex(v, r)
        self._regen_levels()

448
    def get_edges_prob(self, missing, spurious=[], entropy_args={}):
        r"""Compute the joint log-probability of a set of missing and spurious
        edges, together with the observed edges.

        ``missing`` and ``spurious`` are lists of ``(source, target)`` tuples
        or :meth:`~graph_tool.Edge` instances.

        More precisely, the log-likelihood returned is

        .. math::

            \ln \frac{P(\boldsymbol G + \delta \boldsymbol G | \boldsymbol b)}{P(\boldsymbol G| \boldsymbol b)}

        where :math:`\boldsymbol G + \delta \boldsymbol G` is the modified graph
        (with missing edges added and spurious edges deleted).

        The values in ``entropy_args`` are passed to
        :meth:`graph_tool.inference.blockmodel.BlockState.entropy()` to calculate the
        log-probability.
        """
        # Neither default argument is mutated: ``entropy_args`` is copied and
        # ``missing``/``spurious`` are only ever rebound.
        entropy_args = entropy_args.copy()
        hentropy_args = dict(self.hentropy_args,
                             **entropy_args.pop("hentropy_args", {}))
        hentropy_args_top = dict(dict(hentropy_args, edges_dl=True,
                                      recs_dl=True),
                                 **entropy_args.pop("hentropy_args_top", {}))

        top = len(self.levels) - 1
        L = 0
        for l, lstate in enumerate(self.levels):
            if l == 0:
                eargs = entropy_args
            elif l == top:
                eargs = hentropy_args_top
            else:
                eargs = hentropy_args

            if self.sampling:
                # each level is evaluated on its own, without a coupled state
                lstate._couple_state(None, None)
                if l > 0:
                    lstate._state.sync_emat()
                    lstate._state.clear_egroups()

            L += lstate.get_edges_prob(missing, spurious, entropy_args=eargs)

            # express the edge lists in terms of this level's groups, which
            # are the vertices of the next level
            b = lstate.b
            if isinstance(self.levels[0], LayeredBlockState):
                missing = [(b[u], b[v], l_) for u, v, l_ in missing]
                spurious = [(b[u], b[v], l_) for u, v, l_ in spurious]
            else:
                missing = [(b[u], b[v]) for u, v in missing]
                spurious = [(b[u], b[v]) for u, v in spurious]

        return L

501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
    def get_bstack(self):
        """Return the nested levels as individual graphs.

        This returns a list of :class:`~graph_tool.Graph` instances
        representing the inferred hierarchy at each level. Each graph has two
        internal vertex and edge property maps named "count" which correspond to
        the vertex and edge counts at the lower level, respectively. Additionally,
        an internal vertex property map named "b" specifies the block partition.

        The list stops at the first level that has a single node.
        """

        bstack = []
        for l, bstate in enumerate(self.levels):
            cg = bstate.g
            if l == 0:
                # wrap the base graph in a view so the internal maps set
                # below are attached to the view
                cg = GraphView(cg, skip_properties=True)
            cg.vp["b"] = bstate.b.copy()
            if bstate.is_weighted:
                cg.ep["count"] = cg.own_property(bstate.eweight.copy())
                cg.vp["count"] = cg.own_property(bstate.vweight.copy())
            else:
                # Unweighted level: every edge and every vertex counts once.
                # The vertex map is set as well, so that both "count" maps
                # promised by the docstring exist on every returned graph.
                cg.ep["count"] = cg.new_ep("int", 1)
                cg.vp["count"] = cg.new_vp("int", 1)

            bstack.append(cg)
            if bstate.get_N() == 1:
                break
        return bstack

    def project_level(self, l):
        """Project the partition at level ``l`` onto the lowest level, and return the
        corresponding state."""
        b = self.project_partition(l, 0)
532
        return self.levels[0].copy(b=b)
533
534
535
536

    def print_summary(self):
        """Print a hierarchy summary."""
        for l, state in enumerate(self.levels):
537
538
            print("l: %d, N: %d, B: %d" % (l, state.get_N(),
                                           state.get_nonempty_B()))
539

540
541
    def find_new_level(self, l, bisection_args={}, B_min=None, B_max=None,
                       b_min=None, b_max=None):
        """Attempt to find a better network partition at level ``l``, using
        :func:`~graph_tool.inference.bisection.bisection_minimize` with arguments given by
        ``bisection_args``.

        ``B_min`` and ``B_max`` bound the number of groups of the two boundary
        states handed to the bisection, and ``b_max`` optionally supplies the
        partition from which the upper boundary state is built. The new state
        for level ``l`` is returned with no coupled state; ``self.levels`` is
        not modified here.
        """
        # NOTE(review): ``b_min`` is accepted but never used in this method.

        # assemble minimization arguments
        # (``bisection_args`` and the nested dicts are only read; modified
        # versions are rebuilt with ``dict(...)`` below)
        mcmc_multilevel_args = bisection_args.get("mcmc_multilevel_args", {})
        mcmc_equilibrate_args = mcmc_multilevel_args.get("mcmc_equilibrate_args", {})
        mcmc_args = mcmc_equilibrate_args.get("mcmc_args", {})
        entropy_args = mcmc_args.get("entropy_args", {})
        if l > 0:
            # above the lowest level the hierarchy's entropy arguments win
            entropy_args = dict(entropy_args, **self.hentropy_args)
        top = (l == (len(self.levels) - 1))
        entropy_args = dict(entropy_args, edges_dl=top, recs_dl=top)

        def callback(s):
            # Extra entropy terms for a candidate state ``s`` of level ``l``:
            # the entropy of the level above it (when there is one) and the
            # Lrecdx-related terms.
            S = 0
            bstate = None
            if l < len(self.levels) - 1:
                if s._coupled_state is None:
                    # no coupled state: build the level above explicitly
                    bclabel = s.get_bclabel()
                    bstate = s.get_block_state(b=bclabel,
                                               **dict(self.hstate_args,
                                                      Lrecdx=s.Lrecdx))
                    S += bstate.entropy(**dict(self.hentropy_args,
                                               edges_dl=(l + 1 == len(self.levels) - 1),
                                               recs_dl=(l + 1 == len(self.levels) - 1)))
                else:
                    # reuse the coupled state and its stored entropy arguments
                    bstate = s._coupled_state[0]
                    S += bstate.entropy(**dict(s._coupled_state[1], recs=True))

            if self.base_type is not LayeredBlockState:
                if s.Lrecdx[0] >= 0:
                    S += self._Lrecdx_entropy(s.Lrecdx)
                    # visit ``s`` and then (if present) ``bstate``, subtracting
                    # log(B_E_D) once per real_normal covariate
                    ss = s
                    while ss is not None:
                        B_E_D = ss._state.get_B_E_D()
                        if B_E_D > 0:
                            for i in range(len(s.rec)):
                                if s.rec_types[i] != libinference.rec_type.real_normal:
                                    continue
                                S -= log(B_E_D)
                        if l < len(self.levels) - 1 and ss is not bstate:
                            ss = bstate
                        else:
                            ss = None
            else:
                if s.Lrecdx[0][0] >= 0:
                    S += self._Lrecdx_entropy(s.Lrecdx)
                    # same as above, with one term per layer
                    ss = s
                    while ss is not None:
                        for j in range(len(ss.layer_states)):
                            B_E_D = ss._state.get_layer(j).get_B_E_D()
                            if B_E_D > 0:
                                for i in range(len(s.rec)):
                                    if s.rec_types[i] != libinference.rec_type.real_normal:
                                        continue
                                    S -= log(B_E_D)
                        if l < len(self.levels) - 1 and ss is not bstate:
                            ss = bstate
                        else:
                            ss = None

            assert (not _bm_test() or bstate is None or
                    s.get_nonempty_B() == bstate.get_N()), (s.get_nonempty_B(),
                                                            bstate.get_N())
            return S

        entropy_args = dict(entropy_args, callback=callback)
        mcmc_args = dict(mcmc_args, entropy_args=entropy_args)
        if l > 0:
            # presumably drops the "bundled" option for upper levels -- confirm
            # against dmask()
            mcmc_args = dmask(mcmc_args, ["bundled"])
        mcmc_equilibrate_args = dict(mcmc_equilibrate_args,
                                     mcmc_args=mcmc_args)
        shrink_args = mcmc_multilevel_args.get("shrink_args", {})
        shrink_args = dict(shrink_args,
                           entropy_args=dict(shrink_args.get("entropy_args", {}),
                                             **entropy_args))
        if l > 0:
            shrink_args["entropy_args"].update(dict(multigraph=True, dense=True))
        elif not shrink_args["entropy_args"].get("dense", False):
            shrink_args["entropy_args"]["multigraph"] = False

        mcmc_multilevel_args = dict(mcmc_multilevel_args,
                                    shrink_args=shrink_args,
                                    mcmc_equilibrate_args=mcmc_equilibrate_args)
        bisection_args = dict(bisection_args,
                              mcmc_multilevel_args=mcmc_multilevel_args)

        # construct boundary states and constraints
        clabel = self.get_clabel(l)
        state = self.levels[l]
        if b_max is None:
            # default upper boundary: partition given by the vertex index
            b_max = state.g.vertex_index.copy("int").fa
        else:
            # combine the supplied partition with the constraint labels
            b_max = state.g.new_vp("int", b_max)
            b_max = group_vector_property([b_max, clabel])
            b_max = perfect_prop_hash([b_max])[0].fa
        continuous_map(b_max)
        max_state = state.copy(b=b_max, clabel=clabel,
                               recs=[r.copy() for r in state.rec],
                               drec=[r.copy() for r in state.drec])
        # give the boundary state its own Lrecdx, computed as if it replaced
        # level l
        max_Lrecdx = self._regen_Lrecdx(lstate=(l, max_state))
        max_state = max_state.copy(Lrecdx=max_Lrecdx)
        if B_max is not None and max_state.B > B_max:
            max_state = mcmc_multilevel(max_state, B_max,
                                        **mcmc_multilevel_args)

        if l < len(self.levels) - 1:
            if B_min is None:
                # default lower boundary: the constraint labels themselves
                min_state = state.copy(b=clabel.fa, clabel=clabel.fa,
                                       recs=[r.copy() for r in state.rec],
                                       drec=[r.copy() for r in state.drec])
                B_min = min_state.B
            else:
                # cannot go below the number of constraint labels
                B_min = max(B_min, clabel.fa.max() + 1)
                min_state = mcmc_multilevel(max_state, B_min,
                                            **mcmc_multilevel_args)
            if _bm_test():
                assert (min_state.B == self.levels[l+1].B or
                        min_state.B == B_min), (B_min, min_state.B,
                                                self.levels[l+1].B)
        else:
            # NOTE(review): unlike the branch above, this copy does not pass
            # copies of ``recs``/``drec`` -- confirm this is intended.
            min_state = state.copy(b=clabel.fa, clabel=clabel.fa)
        min_Lrecdx = self._regen_Lrecdx(lstate=(l, min_state))
        min_state = min_state.copy(Lrecdx=min_Lrecdx)
        if B_min is not None and  min_state.B > B_min:
            min_state = mcmc_multilevel(min_state, B_min,
                                        **mcmc_multilevel_args)

        if l < len(self.levels) - 1:
            # NOTE(review): only ``edges_dl`` is switched on for a topmost
            # upper level here, whereas ``callback`` above also sets
            # ``recs_dl`` -- confirm.
            eargs = dict(self.hentropy_args,
                         edges_dl=(l + 1 == len(self.levels) - 1))
            # couple each boundary state to a block state derived from it
            min_state._couple_state(min_state.get_block_state(**dict(self.hstate_args,
                                                                     b=min_state.get_bclabel(),
                                                                     copy_bg=False,
                                                                     Lrecdx=min_state.Lrecdx)),
                                    eargs)
            max_state._couple_state(max_state.get_block_state(**dict(self.hstate_args,
                                                                     b=max_state.get_bclabel(),
                                                                     copy_bg=False,
                                                                     Lrecdx=max_state.Lrecdx)),
                                    eargs)

        if _bm_test():
            assert min_state._check_clabel(), "invalid clabel %s" % str((l, self))
            assert max_state._check_clabel(), "invalid clabel %s" % str((l, self))

        # find new state
        state = bisection_minimize([min_state, max_state], **bisection_args)

        if _bm_test():
            assert state.B >= min_state.B, (l, state.B, min_state.B, str(self))
            assert state._check_clabel(), "invalid clabel %s" % str((l, self))

        # the returned state is handed back without a coupled state
        state._couple_state(None, None)
        return state

700
701
702
    def _couple_levels(self, hentropy_args, hentropy_args_top):
        if hentropy_args_top is None:
            hentropy_args_top = dict(hentropy_args, edges_dl=True, recs_dl=True)
703
        for l in range(len(self.levels) - 1):
704
705
706
707
            if l + 1 == len(self.levels) - 1:
                eargs = hentropy_args_top
            else:
                eargs = hentropy_args
708
709
            self.levels[l]._couple_state(self.levels[l + 1], eargs)

710
711
712
713
    def _clear_egroups(self):
        for lstate in self.levels:
            lstate._clear_egroups()

714
    def _h_sweep_gen(self, **kwargs):
715
716
717
718

        if not self.sampling:
            raise ValueError("NestedBlockState must be constructed with 'sampling=True'")

719
        verbose = kwargs.get("verbose", False)
720
721
722
723
724
725
        entropy_args = dict(kwargs.get("entropy_args", {}), edges_dl=False)
        hentropy_args = dict(self.hentropy_args,
                             **entropy_args.pop("hentropy_args", {}))
        hentropy_args_top = dict(dict(hentropy_args, edges_dl=True,
                                      recs_dl=True),
                                 **entropy_args.pop("hentropy_args_top", {}))
726

727
        self._couple_levels(hentropy_args, hentropy_args_top)
728

729
730
        c = kwargs.get("c", None)

731
        lrange = list(kwargs.pop("ls", range(len(self.levels))))
732
733
        if kwargs.pop("ls_shuffle", True):
            numpy.random.shuffle(lrange)
734
        for l in lrange:
735
736
737
            if check_verbose(verbose):
                print(verbose_pad(verbose) + "level:", l)
            if l > 0:
738
739
740
741
                if l == len(self.levels) - 1:
                    eargs = hentropy_args_top
                else:
                    eargs = hentropy_args
742
743
744
            else:
                eargs = entropy_args

745
            if c is None:
746
                args = dict(kwargs, entropy_args=eargs)
747
            else:
748
                args = dict(kwargs, entropy_args=eargs, c=c[l])
749

750
751
            if l > 0 and "beta_dl" in entropy_args:
                args = dict(args, beta=args.get("beta", 1.) * entropy_args["beta_dl"])
752

753
754
755
756
757
758
759
760
761
            yield l, self.levels[l], args

    def _h_sweep(self, algo, **kwargs):
        entropy_args = kwargs.get("entropy_args", {})

        dS = 0
        nattempts = 0
        nmoves = 0

762
763
        try:
            for l, lstate, args in self._h_sweep_gen(**kwargs):
764

765
                ret = algo(self.levels[l], **args)
766

767
768
769
770
771
772
773
774
775
                if l > 0 and "beta_dl" in entropy_args:
                    dS += ret[0] * entropy_args["beta_dl"]
                else:
                    dS += ret[0]
                nattempts += ret[1]
                nmoves += ret[2]
        finally:
            for state in self.levels:
                state.B = state.bg.num_vertices()
776

777
        return dS, nattempts, nmoves
778

779
780
781
    def _h_sweep_states(self, algo, **kwargs):
        entropy_args = kwargs.get("entropy_args", {})
        for l, lstate, args in self._h_sweep_gen(**kwargs):
782
783
            beta_dl = entropy_args.get("beta_dl", 1) if l > 0 else 1
            yield l, lstate, algo(self.levels[l], dispatch=False, **args), beta_dl
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800

    def _h_sweep_parallel_dispatch(states, sweeps, algo):
        ret = None
        for lsweep in zip(*sweeps):
            ls = [x[0] for x in lsweep]
            lstates = [x[1] for x in lsweep]
            lsweep_states = [x[2] for x in lsweep]
            beta_dl = [x[3] for x in lsweep]
            lret = algo(type(lstates[0]), lstates, lsweep_states)
            if ret is None:
                ret = lret
            else:
                ret = [(ret[i][0] + lret[i][0] * beta_dl[i],
                        ret[i][1] + lret[i][1],
                        ret[i][2] + lret[i][2]) for i in range(len(lret))]
        return ret

801
802
    def mcmc_sweep(self, **kwargs):
        r"""Perform ``niter`` sweeps of a Metropolis-Hastings acceptance-rejection
        MCMC to sample hierarchical network partitions.

        The arguments accepted are the same as in
        :meth:`graph_tool.inference.blockmodel.BlockState.mcmc_sweep`.

        If the parameter ``c`` is a scalar, the values used at each level are
        ``c * 2 ** l`` for ``l`` in the range ``[0, L-1]``. Optionally, a list
        of values may be passed instead, which specifies the value of ``c[l]``
        to be used at each level.

        If ``dispatch=False`` is passed, no sweep is performed here; instead
        the generator returned by ``_h_sweep_states`` is handed back.
        Otherwise a ``(dS, nattempts, nmoves)`` tuple is returned.
        """
        # ``collections.Iterable`` was removed in Python 3.10; the fallback
        # keeps Python 2 working.
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable

        c = kwargs.pop("c", 1)
        if not isinstance(c, Iterable):
            c = [c * 2 ** l for l in range(len(self.levels))]

        if not kwargs.pop("dispatch", True):
            return self._h_sweep_states(lambda s, **a: s.mcmc_sweep(**a),
                                        c=c, **kwargs)

        if _bm_test():
            kwargs = dict(kwargs, test=False)
            entropy_args = kwargs.get("entropy_args", {})
            Si = self.entropy(**entropy_args)

        dS, nattempts, nmoves = self._h_sweep(lambda s, **a: s.mcmc_sweep(**a),
                                              c=c, **kwargs)

        if _bm_test():
            # Consistency check: the reported delta must match the actual
            # change in entropy.
            Sf = self.entropy(**entropy_args)
            assert math.isclose(dS, (Sf - Si), abs_tol=1e-8), \
                "inconsistent entropy delta %g (%g): %s" % (dS, Sf - Si,
                                                            str(entropy_args))
        return dS, nattempts, nmoves

    def _mcmc_sweep_parallel_dispatch(states, sweeps):
        r"""Dispatch parallel ``mcmc_sweep`` steps through
        ``NestedBlockState._h_sweep_parallel_dispatch``, delegating each step
        to the ``_mcmc_sweep_parallel_dispatch`` of the level state type."""
        def run_step(state_type, lstates, lsweep_states):
            return state_type._mcmc_sweep_parallel_dispatch(lstates,
                                                            lsweep_states)
        return NestedBlockState._h_sweep_parallel_dispatch(states, sweeps,
                                                           run_step)
840

841
842
843
844
845
    def multiflip_mcmc_sweep(self, **kwargs):
        r"""Perform ``niter`` sweeps of a Metropolis-Hastings acceptance-rejection MCMC
        with multiple moves to sample hierarchical network partitions.

        The arguments accepted are the same as in
        :meth:`graph_tool.inference.blockmodel.BlockState.multiflip_mcmc_sweep`.

        If the parameter ``c`` is a scalar, the values used at each level are
        ``c * 2 ** l`` for ``l`` in the range ``[0, L-1]``. Optionally, a list
        of values may be passed instead, which specifies the value of ``c[l]``
        to be used at each level.

        If ``psingle`` is not given, it defaults to the number of vertices of
        the underlying graph. If ``dispatch=False`` is passed, no sweep is
        performed here; instead the generator returned by ``_h_sweep_states``
        is handed back. Otherwise a ``(dS, nattempts, nmoves)`` tuple is
        returned.
        """
        # ``collections.Iterable`` was removed in Python 3.10; the fallback
        # keeps Python 2 working.
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable

        kwargs.setdefault("psingle", self.g.num_vertices())

        c = kwargs.pop("c", 1)
        if not isinstance(c, Iterable):
            c = [c * 2 ** l for l in range(len(self.levels))]

        if not kwargs.pop("dispatch", True):
            return self._h_sweep_states(lambda s, **a: s.multiflip_mcmc_sweep(**a),
                                        c=c, **kwargs)

        if _bm_test():
            kwargs = dict(kwargs, test=False)
            entropy_args = kwargs.get("entropy_args", {})
            Si = self.entropy(**entropy_args)

        dS, nattempts, nmoves = self._h_sweep(lambda s, **a: s.multiflip_mcmc_sweep(**a),
                                              c=c, **kwargs)

        if _bm_test():
            # Consistency check: the reported delta must match the actual
            # change in entropy.
            Sf = self.entropy(**entropy_args)
            assert math.isclose(dS, (Sf - Si), abs_tol=1e-8), \
                r"inconsistent entropy delta %g (%g): %s" % (dS, Sf - Si,
                                                             str(entropy_args))
        return dS, nattempts, nmoves

    def _multiflip_mcmc_sweep_parallel_dispatch(states, sweeps):
        r"""Dispatch parallel ``multiflip_mcmc_sweep`` steps through
        ``NestedBlockState._h_sweep_parallel_dispatch``, delegating each step
        to the ``_multiflip_mcmc_sweep_parallel_dispatch`` of the level state
        type."""
        def run_step(state_type, lstates, lsweep_states):
            return state_type._multiflip_mcmc_sweep_parallel_dispatch(lstates,
                                                                      lsweep_states)
        return NestedBlockState._h_sweep_parallel_dispatch(states, sweeps,
                                                           run_step)
883

884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
    def gibbs_sweep(self, **kwargs):
        r"""Perform ``niter`` sweeps of a rejection-free Gibbs sampling MCMC
        to sample network partitions.

        The arguments accepted are the same as in
        :meth:`graph_tool.inference.blockmodel.BlockState.gibbs_sweep`.

        Returns a ``(dS, nattempts, nmoves)`` tuple accumulated over the
        hierarchy levels.
        """
        def sweep_level(level_state, **level_args):
            return level_state.gibbs_sweep(**level_args)

        if _bm_test():
            kwargs = dict(kwargs, test=False)
            entropy_args = kwargs.get("entropy_args", {})
            Si = self.entropy(**entropy_args)

        result = self._h_sweep(sweep_level, **kwargs)
        dS, nattempts, nmoves = result

        if _bm_test():
            # Consistency check: the reported delta must match the actual
            # change in entropy.
            Sf = self.entropy(**entropy_args)
            delta = Sf - Si
            assert math.isclose(dS, delta, abs_tol=1e-8), \
                "inconsistent entropy delta %g (%g): %s" % (dS, Sf - Si,
                                                            str(entropy_args))
        return dS, nattempts, nmoves

    def _gibbs_sweep_parallel_dispatch(states, sweeps):
        r"""Dispatch parallel ``gibbs_sweep`` steps through
        ``NestedBlockState._h_sweep_parallel_dispatch``, delegating each step
        to the ``_gibbs_sweep_parallel_dispatch`` of the level state type."""
        def run_step(state_type, lstates, lsweep_states):
            return state_type._gibbs_sweep_parallel_dispatch(lstates,
                                                             lsweep_states)
        return NestedBlockState._h_sweep_parallel_dispatch(states, sweeps,
                                                           run_step)

910
911
912
913
914
    def multicanonical_sweep(self, **kwargs):
        r"""Perform ``niter`` sweeps of a non-Markovian multicanonical sampling using the
        Wang-Landau algorithm.

        The arguments accepted are the same as in
        :meth:`graph_tool.inference.blockmodel.BlockState.multicanonical_sweep`.

        Returns a ``(dS, nattempts, nmoves)`` tuple accumulated over the
        hierarchy levels.
        """
        if _bm_test():
            kwargs = dict(kwargs, test=False)
            entropy_args = kwargs.get("entropy_args", {})
            Si = self.entropy(**entropy_args)

        # The keyword arguments must be forwarded, otherwise everything the
        # caller passed (and the ``test=False`` set above) is silently dropped.
        dS, nattempts, nmoves = self._h_sweep(lambda s, **a: s.multicanonical_sweep(**a),
                                              **kwargs)

        if _bm_test():
            # Consistency check: the reported delta must match the actual
            # change in entropy.
            Sf = self.entropy(**entropy_args)
            assert math.isclose(dS, (Sf - Si), abs_tol=1e-8), \
                "inconsistent entropy delta %g (%g): %s" % (dS, Sf - Si,
                                                            str(entropy_args))
        return dS, nattempts, nmoves
930

931
932
933
934
    def collect_partition_histogram(self, h=None, update=1):
        r"""Collect a histogram of partitions.

        This should be called multiple times, e.g. after repeated runs of the
        :meth:`graph_tool.inference.nested_blockmodel.NestedBlockState.mcmc_sweep` function.

        Parameters
        ----------
        h : :class:`~graph_tool.inference.blockmodel.PartitionHist` (optional, default: ``None``)
            Partition histogram. If not provided, an empty histogram will be created.
        update : float (optional, default: ``1``)
            Each call increases the current count by the amount given by this
            parameter.

        Returns
        -------
        h : :class:`~graph_tool.inference.blockmodel.PartitionHist`
            Updated Partition histogram.

        """

        hist = PartitionHist() if h is None else h

        # One vertex property map per hierarchy level.
        level_maps = []
        for lstate in self.levels:
            level_maps.append(_prop("v", lstate.g, lstate.b))

        libinference.collect_hierarchical_partitions(level_maps, hist, update)
        return hist

958
959
960
961
962
963
964
965
966
    def draw(self, **kwargs):
        r"""Convenience wrapper to :func:`~graph_tool.draw.draw_hierarchy` that
        draws the hierarchical state.

        All keyword arguments are forwarded unchanged, and the return value of
        :func:`~graph_tool.draw.draw_hierarchy` is handed back.
        """
        # Imported lazily, at call time rather than at module import.
        from graph_tool.draw import draw_hierarchy
        return draw_hierarchy(self, **kwargs)



def hierarchy_minimize(state, B_min=None, B_max=None, b_min=None, b_max=None,
967
                       frozen_levels=None, bisection_args={},
968
                       epsilon=1e-8, verbose=False):
969
970
971
972
973
    """Attempt to find a fit of the nested stochastic block model that minimizes the
    description length.

    Parameters
    ----------
974
    state : :class:`~graph_tool.inference.nested_blockmodel.NestedBlockState`
975
976
977
978
979
        The nested block state.
    B_min : ``int`` (optional, default: ``None``)
        The minimum number of blocks.
    B_max : ``int`` (optional, default: ``None``)
        The maximum number of blocks.
980
    b_min : :class:`~graph_tool.VertexPropertyMap` (optional, default: ``None``)
981
        The partition to be used with the minimum number of blocks.
982
    b_max : :class:`~graph_tool.VertexPropertyMap` (optional, default: ``None``)
983
        The partition to be used with the maximum number of blocks.
Tiago Peixoto's avatar
Tiago Peixoto committed
984
    frozen_levels : sequence of ``int`` values (optional, default: ``None``)
985
986
        List of hierarchy levels that are kept constant during the minimization.
    bisection_args : ``dict`` (optional, default: ``{}``)
987
        Arguments to be passed to :func:`~graph_tool.inference.bisection.bisection_minimize`.
988
989
990
    epsilon: ``float`` (optional, default: ``1e-8``)
        Only replace levels if the description length difference is above this
        threshold.
991
992
993
994
995
996
997
998
999
    verbose : ``bool`` or ``tuple`` (optional, default: ``False``)
        If ``True``, progress information will be shown. Optionally, this
        accepts arguments of the type ``tuple`` of the form ``(level, prefix)``
        where ``level`` is a positive integer that specifies the level of
        detail, and ``prefix`` is a string that is prepended to the all output
        messages.

    Returns
    -------
1000
    min_state : :class:`~graph_tool.inference.nested_blockmodel.NestedBlockState`
For faster browsing, not all history is shown. View entire blame