#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# graph_tool -- a general graph manipulation python module
#
# Copyright (C) 2006-2020 Tiago de Paula Peixoto <tiago@skewed.de>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

Alex Henrie's avatar
Alex Henrie committed
21
from .. import _prop, Graph, GraphView, libcore, _get_rng
22
23
24
from numpy import *
import numpy
import copy
25
import warnings
26

Alex Henrie's avatar
Alex Henrie committed
27
from .. import group_vector_property, ungroup_vector_property, perfect_prop_hash
28
29
30
31
32
33
34
35
36
37

from .. dl_import import dl_import
dl_import("from . import libgraph_tool_inference as libinference")

from .. generation import graph_union

from . blockmodel import *
from . blockmodel import _bm_test
from . overlap_blockmodel import *

Tiago Peixoto's avatar
Tiago Peixoto committed
38
class LayeredBlockState(OverlapBlockState, BlockState):
    r"""The (possibly overlapping) block state of a given graph, where the edges are
    divided into discrete layers.

    Parameters
    ----------
    g : :class:`~graph_tool.Graph`
        Graph to be modelled.
    ec : :class:`~graph_tool.EdgePropertyMap`
        Edge property map containing discrete edge covariates that will split
        the network into discrete layers.
    recs : list of :class:`~graph_tool.EdgePropertyMap` instances (optional, default: ``[]``)
        List of real or discrete-valued edge covariates.
    rec_types : list of edge covariate types (optional, default: ``[]``)
        List of types of edge covariates. The possible types are:
        ``"real-exponential"``, ``"real-normal"``, ``"discrete-geometric"``,
        ``"discrete-poisson"`` or ``"discrete-binomial"``.
    rec_params : list of ``dict`` (optional, default: ``[]``)
        Model hyperparameters for edge covariates. This should be a list of
        ``dict`` instances. See :class:`~graph_tool.inference.blockmodel.BlockState` for
        more details.
    eweight : :class:`~graph_tool.EdgePropertyMap` (optional, default: ``None``)
        Edge multiplicities (for multigraphs or block graphs).
    vweight : :class:`~graph_tool.VertexPropertyMap` (optional, default: ``None``)
        Vertex multiplicities (for block graphs).
    b : :class:`~graph_tool.VertexPropertyMap` or :class:`numpy.ndarray` (optional, default: ``None``)
        Initial block labels on the vertices or half-edges. If not supplied, it
        will be randomly sampled.
    B : ``int`` (optional, default: ``None``)
        Number of blocks (or vertex groups). If not supplied it will be obtained
        from the parameter ``b``.
    clabel : :class:`~graph_tool.VertexPropertyMap` (optional, default: ``None``)
        Constraint labels on the vertices. If supplied, vertices with different
        label values will not be clustered in the same group.
    pclabel : :class:`~graph_tool.VertexPropertyMap` (optional, default: ``False``)
        Partition constraint labels on the vertices. This has the same
        interpretation as ``clabel``, but will be used to compute the partition
        description length.
    layers : ``bool`` (optional, default: ``False``)
        If ``layers == True``, the "independent layers" version of the model is
        used, instead of the "edge covariates" version.
    deg_corr : ``bool`` (optional, default: ``True``)
        If ``True``, the degree-corrected version of the blockmodel ensemble will
        be assumed, otherwise the traditional variant will be used.
    overlap : ``bool`` (optional, default: ``False``)
        If ``True``, the overlapping version of the model will be used.
    """

    def __init__(self, g, ec, eweight=None, vweight=None, recs=None,
                 rec_types=None, rec_params=None, b=None, B=None, clabel=None,
                 pclabel=False, layers=False, deg_corr=True, overlap=False,
                 **kwargs):
        r"""Initialise the layered block state.

        See the class docstring for the public parameters. ``recs``,
        ``rec_types`` and ``rec_params`` default to empty lists; ``None`` is
        accepted and treated as ``[]``, so no list is shared between calls.

        Internal keyword arguments (used by :meth:`copy` and
        :meth:`get_block_state`):

        - ``ec_done``: skip relabelling ``ec`` with ``perfect_prop_hash``.
        - ``gs``: prebuilt per-layer graphs (required when ``ec`` is ``None``).
        - ``lweights``, ``degs``, ``Lrecdx``.
        - ``base_g``, ``node_index``, ``eindex``, ``half_edges``: overlapping
          states only.

        A ``dense_bg`` argument is accepted but ignored; dense block graphs
        are always disabled here.
        """
        # Fresh lists per call instead of shared mutable defaults.
        recs = [] if recs is None else recs
        rec_types = [] if rec_types is None else rec_types
        rec_params = [] if rec_params is None else rec_params

        kwargs = kwargs.copy()

        self.g = g

        # Relabel the layer map to consecutive integers unless the caller
        # says this was already done.
        if kwargs.pop("ec_done", False) or ec is None:
            self.ec = ec
        else:
            self.ec = ec = perfect_prop_hash([ec], "int32_t")[0]

        if ec is not None:
            self.C = ec.fa.max() + 1
        else:
            # Without a layer map, the layers are the graphs given in ``gs``.
            self.C = len(kwargs.get("gs"))

        self.layers = layers

        # Dense block graphs are not supported; drop any such request.
        if "dense_bg" in kwargs:
            del kwargs["dense_bg"]
        dense_bg = False

        if vweight is None:
            vweight = g.new_vp("int", 1)

        if eweight is None:
            eweight = g.new_ep("int", 1)

        # One entry for the aggregated state (index 0) plus one per layer.
        self.Lrecdx = kwargs.pop("Lrecdx", [])
        while len(self.Lrecdx) < self.C + 1:
            self.Lrecdx.append(libcore.Vector_double(1))
            self.Lrecdx[-1][0] = -1

        if not overlap:
            kwargs = dmask(kwargs, ["base_g", "node_index", "eindex",
                                    "half_edges"])
            ldegs = kwargs.pop("degs", None)
            if ldegs is not None:
                # Degrees for the aggregated state are stored at index 0.
                tdegs = libinference.get_mapped_block_degs(self.g._Graph__graph,
                                                           ldegs, 0,
                                                           _prop("v", self.g,
                                                                 self.g.vertex_index.copy("int")))
            else:
                tdegs = None

            agg_state = BlockState(g, b=b, B=B,
                                   eweight=eweight, vweight=vweight, recs=recs,
                                   rec_types=rec_types, rec_params=rec_params,
                                   clabel=clabel, pclabel=pclabel,
                                   deg_corr=deg_corr,
                                   dense_bg=dense_bg, degs=tdegs,
                                   Lrecdx=self.Lrecdx[0],
                                   use_rmap=True,
                                   **dmask(kwargs, ["degs", "lweights", "gs"]))
        else:
            kwargs = dmask(kwargs, ["degs"])
            ldegs = None
            agg_state = OverlapBlockState(g, b=b, B=B, recs=recs,
                                          rec_types=rec_types,
                                          rec_params=rec_params, clabel=clabel,
                                          pclabel=pclabel, deg_corr=deg_corr,
                                          dense_bg=dense_bg,
                                          Lrecdx=self.Lrecdx[0],
                                          **dmask(kwargs, ["degs", "lweights",
                                                           "gs", "bfield"]))
            # From here on, self.g is the overlap state's graph; the original
            # graph is kept as self.base_g.
            self.base_g = agg_state.base_g
            self.g = agg_state.g
            eweight = self.g.new_ep("int", 1)
            vweight = self.g.new_vp("int", 1)
            kwargs = dmask(kwargs, ["base_g", "node_index", "eindex",
                                    "half_edges"])

        self.agg_state = agg_state

        if overlap and self.ec is not None:
            # Keep the layer map for the base graph, and carry it over to the
            # edges of the overlap graph through ``eindex``.
            self.base_ec = self.base_g.own_property(ec.copy())
            ec = agg_state.eindex.copy()
            pmap(ec, self.ec)
            self.ec = ec.copy("int")

        self.eweight = eweight
        self.vweight = vweight
        if not overlap:
            self.is_weighted = agg_state.is_weighted
        else:
            self.is_weighted = False

        self.rec = agg_state.rec
        self.drec = agg_state.drec
        self.rec_types = agg_state.rec_types
        self.rec_params = agg_state.rec_params
        self.epsilon = agg_state.epsilon

        self.b = agg_state.b
        self.B = agg_state.B
        self.clabel = agg_state.clabel
        self.pclabel = agg_state.pclabel
        self.bclabel = agg_state.bclabel
        self.hclabel = agg_state.hclabel
        if not overlap:
            self.bfield = agg_state.bfield
        else:
            self.bfield = None

        self.deg_corr = deg_corr
        self.overlap = overlap

        self.vc = self.g.new_vp("vector<int>")
        self.vmap = self.g.new_vp("vector<int>")

        self.gs = kwargs.pop("gs", [])
        self.block_map = libinference.bmap_t()

        # Allocate the default map only when the caller did not pass one.
        lweights = kwargs.pop("lweights", None)
        if lweights is None:
            lweights = self.g.new_vp("vector<int>")

        if len(self.gs) == 0:
            # Build one empty graph per layer and let split_layers fill them.
            for l in range(0, self.C):
                u = Graph(directed=g.is_directed())
                u.vp["b"] = u.new_vp("int")
                u.vp["weight"] = u.new_vp("int")
                u.ep["weight"] = u.new_ep("int")
                u.gp["rec"] = u.new_gp("object", val=[u.new_ep("double") for i in range(len(self.rec))])
                u.gp["drec"] = u.new_gp("object", val=[u.new_ep("double") for i in range(len(self.drec))])
                u.vp["brmap"] = u.new_vp("int")
                u.vp["vmap"] = u.new_vp("int")
                self.gs.append(u)

            libinference.split_layers(self.g._Graph__graph,
                                      _prop("e", self.g, self.ec),
                                      _prop("v", self.g, self.b),
                                      [_prop("e", self.g, x) for x in self.rec],
                                      [_prop("e", self.g, x) for x in self.drec],
                                      _prop("e", self.g, self.eweight),
                                      _prop("v", self.g, self.vweight),
                                      _prop("v", self.g, self.vc),
                                      _prop("v", self.g, self.vmap),
                                      _prop("v", self.g, lweights),
                                      [u._Graph__graph for u in self.gs],
                                      [_prop("v", u, u.vp["b"]) for u in self.gs],
                                      [[_prop("e", u, x) for x in u.gp["rec"]] for u in self.gs],
                                      [[_prop("e", u, x) for x in u.gp["drec"]] for u in self.gs],
                                      [_prop("e", u, u.ep["weight"]) for u in self.gs],
                                      [_prop("v", u, u.vp["weight"]) for u in self.gs],
                                      self.block_map,
                                      [_prop("v", u, u.vp["brmap"]) for u in self.gs],
                                      [_prop("v", u, u.vp["vmap"]) for u in self.gs])
        else:
            # Layer graphs were supplied; only the group bookkeeping is needed.
            libinference.split_groups(_prop("v", self.g, self.b),
                                      _prop("v", self.g, self.vc),
                                      _prop("v", self.g, self.vmap),
                                      [u._Graph__graph for u in self.gs],
                                      [_prop("v", u, u.vp["b"]) for u in self.gs],
                                      [_prop("v", u, u.vp["weight"]) for u in self.gs],
                                      self.block_map,
                                      [_prop("v", u, u.vp["brmap"]) for u in self.gs],
                                      [_prop("v", u, u.vp["vmap"]) for u in self.gs])

        if self.g.get_vertex_filter()[0] is not None:
            for u in self.gs:
                u.set_vertex_filter(u.new_vp("bool", True))

        self.master = not self.layers

        if not overlap:
            self.degs = agg_state.degs
            self.merge_map = agg_state.merge_map

        self.layer_states = []

        self.dense_bg = dense_bg
        self.bg = agg_state.bg
        self.wr = agg_state.wr
        self.mrs = agg_state.mrs
        self.mrp = agg_state.mrp
        self.mrm = agg_state.mrm
        self.brec = agg_state.brec
        self.bdrec = agg_state.bdrec
        # rec_params and epsilon were already copied from agg_state above.
        self.wparams = agg_state.wparams
        self._entropy_args = agg_state._entropy_args
        self.recdx = agg_state.recdx
        self.candidate_blocks = agg_state.candidate_blocks
        self.candidate_pos = agg_state.candidate_pos
        self.empty_blocks = agg_state.empty_blocks
        self.empty_pos = agg_state.empty_pos

        self._coupled_state = None

        for l, u in enumerate(self.gs):
            state = self.__gen_state(l, u, ldegs)
            self.layer_states.append(state)

        # Temporarily provide an (empty) layer map while the C++ state is
        # built from ``self``; it is reset to None right afterwards.
        # NOTE(review): presumably the C++ side reads ``self.ec`` — confirm.
        if ec is None:
            self.ec = self.g.new_ep("int")

        if not self.overlap:
            self._state = \
                libinference.make_layered_block_state(agg_state._state,
                                                      self)
        else:
            self._state = \
                libinference.make_layered_overlap_block_state(agg_state._state,
                                                              self)
        if ec is None:
            self.ec = None

        if _bm_test():
            assert self.mrs.fa.sum() == self.eweight.fa.sum(), "inconsistent mrs!"

        # Internal arguments that were consumed above (or passed through to
        # the aggregated state) must not trigger the warning below.
        kwargs.pop("recs", None)
        kwargs.pop("drec", None)
        kwargs.pop("rec_params", None)
        kwargs.pop("Lrecdx", None)
        kwargs.pop("epsilon", None)
        kwargs.pop("bfield", None)

        if len(kwargs) > 0:
            warnings.warn("unrecognized keyword arguments: " +
                          str(list(kwargs.keys())))

    def get_N(self):
        r"Returns the total number of nodes."
        # Delegates to the aggregated (all-layers) state.
        return self.agg_state.get_N()

    def get_E(self):
        r"Returns the total number of edges."
        # Delegates to the aggregated (all-layers) state.
        return self.agg_state.get_E()

    def get_B(self):
        r"""Returns the total number of blocks of the aggregated state (see
        :meth:`get_nonempty_B` for only the occupied ones)."""
        aggregated = self.agg_state
        return aggregated.get_B()

    def get_nonempty_B(self):
        r"""Returns the number of nonempty blocks, as reported by the
        aggregated state."""
        aggregated = self.agg_state
        return aggregated.get_nonempty_B()

    def __get_base_u(self, u):
        r"""Build the condensed base graph for the layer graph ``u``.

        Returns ``(base_u, node_index)``. ``base_u`` is the condensation of
        ``u`` by base-graph node, with ``base_u.vp["vmap"]`` set to the
        condensation's vertex map. ``node_index`` is the per-vertex node
        index of ``u``, relabelled through the reverse of that map.
        """
        # Translate the layer vertex map into node indices of the
        # aggregated (overlap) state.
        base_index = u.vp["vmap"].copy("int64_t")
        pmap(base_index, self.agg_state.node_index)

        condensed = condensation_graph(u, base_index,
                                       self_loops=True,
                                       parallel_edges=True)
        base_u, cond_index, _vcount, _ecount = condensed[:4]

        # Reverse the condensation map (via reverse_map) and relabel
        # base_index through it.
        inverse = zeros(cond_index.a.max() + 1, dtype="int64")
        reverse_map(cond_index, inverse)
        pmap(base_index, inverse)

        base_u.vp["vmap"] = cond_index
        return base_u, base_index

    def __gen_state(self, l, u, ldegs):
        r"""Create the block state of layer ``l`` from its layer graph ``u``.

        ``ldegs`` are optional layered degrees; they are only used in the
        non-overlapping case, where this layer's degrees are extracted with
        index ``l + 1``. ``Lrecdx`` is also taken at index ``l + 1``.
        """
        # Options common to the overlapping and non-overlapping states.
        shared = dict(B=u.num_vertices() + 1,
                      recs=u.gp["rec"],
                      drec=u.gp["drec"],
                      rec_types=self.rec_types,
                      rec_params=self.rec_params,
                      epsilon=self.epsilon,
                      Lrecdx=self.Lrecdx[l + 1],
                      deg_corr=self.deg_corr,
                      dense_bg=self.dense_bg)

        if self.overlap:
            base_u, node_index = self.__get_base_u(u)
            state = OverlapBlockState(u, b=u.vp["b"].fa,
                                      node_index=node_index,
                                      base_g=base_u,
                                      **shared)
        else:
            degs = None
            if ldegs is not None:
                degs = libinference.get_mapped_block_degs(u._Graph__graph,
                                                          ldegs, l + 1,
                                                          _prop("v", u,
                                                                u.vp.vmap))
            state = BlockState(u, b=u.vp["b"],
                               eweight=u.ep["weight"],
                               vweight=u.vp["weight"],
                               degs=degs,
                               use_rmap=True,
                               **shared)

        # Attach the layer-to-global block and vertex maps to the state.
        state.block_rmap = u.vp["brmap"]
        state.vmap = u.vp["vmap"]
        return state

    def __getstate__(self):
        r"""Return the constructor arguments needed to rebuild this state.

        The dictionary is passed back to ``__init__`` by :meth:`__setstate__`.
        ``overlap`` is included so that an overlapping state is not silently
        restored as a non-overlapping one. For overlapping states, the
        half-edge bookkeeping (``base_g``, ``half_edges``, ``node_index``,
        ``eindex``) is stored as well, the same values that :meth:`copy` hands
        to the constructor.
        """
        state = dict(g=self.g,
                     ec=self.ec,
                     recs=self.rec,
                     drec=self.drec,
                     rec_types=self.rec_types,
                     rec_params=self.rec_params,
                     layers=self.layers,
                     eweight=self.eweight,
                     vweight=self.vweight,
                     b=self.b,
                     B=self.B,
                     clabel=self.clabel,
                     pclabel=self.pclabel,
                     bfield=self.bfield,
                     deg_corr=self.deg_corr,
                     overlap=self.overlap)
        if self.overlap:
            # self.g is the overlap state's graph; the constructor needs the
            # original base graph and its mappings to avoid rebuilding from it.
            state.update(base_g=self.base_g,
                         half_edges=self.agg_state.half_edges,
                         node_index=self.agg_state.node_index,
                         eindex=self.agg_state.eindex)
        return state

    def __setstate__(self, state):
        r"""Rebuild the object in place from the dictionary produced by
        :meth:`__getstate__`, by running the constructor on it."""
        constructor_args = dict(state)
        self.__init__(**constructor_args)

    def __copy__(self):
        r"""Support :func:`copy.copy` by delegating to :meth:`copy`."""
        duplicate = self.copy()
        return duplicate

    def __deepcopy__(self, memo):
        r"""Support :func:`copy.deepcopy`: duplicate the graph (and the layer
        map, when there is one) and build a copy of the state on it.

        States without a layer map (``self.ec is None``, as produced by
        :meth:`get_block_state` with ``copy_bg=False``) are copied by passing
        ``ec=None`` to :meth:`copy`. That makes :meth:`copy` duplicate the
        layer graphs, instead of calling ``own_property`` on ``None``.
        """
        g = copy.deepcopy(self.g, memo)
        if self.ec is None:
            ec = None
        else:
            ec = g.own_property(copy.deepcopy(self.ec, memo))
        return self.copy(g=g, ec=ec)

    def copy(self, g=None, eweight=None, vweight=None, b=None, B=None,
             deg_corr=None, clabel=None, pclabel=None, bfield=None,
             overlap=None, layers=None, ec=None, **kwargs):
        r"""Return a copy of the block state.

        Each argument left as ``None`` is taken from the current state;
        otherwise it overrides the corresponding property and has the same
        meaning as in the constructor. Remaining keyword arguments (for
        instance ``recs``, ``drec``, ``rec_types``, ``rec_params``, ``Lrecdx``
        or ``epsilon``) are forwarded to the constructor.
        """
        def keep(value, current):
            # An explicit override wins; None means "keep the current value".
            return current if value is None else value

        lweights = self.g.new_vp("vector<int>")
        degs = None
        if not self.overlap:
            layer_vweights = [_prop("v", ls.g, ls.vweight)
                              for ls in self.layer_states]
            libinference.get_lweights(self.g._Graph__graph,
                                      _prop("v", self.g, self.vc),
                                      _prop("v", self.g, self.vmap),
                                      _prop("v", self.g, lweights),
                                      layer_vweights)
            # Only non-simple degrees are carried over explicitly.
            if not isinstance(self.agg_state.degs, libinference.simple_degs_t):
                all_degs = [self.agg_state.degs]
                all_degs.extend(ls.degs for ls in self.layer_states)
                degs = libinference.get_ldegs(self.g._Graph__graph,
                                              _prop("v", self.g, self.vc),
                                              _prop("v", self.g, self.vmap),
                                              all_degs)

        # A new layer map forces the layer graphs to be rebuilt from it;
        # otherwise the current layer graphs are duplicated.
        if ec is None:
            gs = [u.copy() for u in self.gs]
            ec = self.ec
        else:
            gs = []

        if gs:
            libinference.get_rvmap(self.g._Graph__graph,
                                   _prop("v", self.g, self.vc),
                                   _prop("v", self.g, self.vmap),
                                   [_prop("v", u, u.vp.vmap) for u in gs])
            for u in gs:
                # Give each copied graph its own covariate maps.
                u.gp.rec = [u.own_property(x.copy()) for x in u.gp.rec]
                if u.gp.drec is not None:
                    u.gp.drec = [u.own_property(x.copy()) for x in u.gp.drec]

        recs = kwargs.pop("recs", self.rec)
        drec = kwargs.pop("drec", self.drec)
        rec_types = kwargs.pop("rec_types", self.rec_types)
        rec_params = kwargs.pop("rec_params", self.rec_params)
        Lrecdx = kwargs.pop("Lrecdx", [x.copy() for x in self.Lrecdx])
        epsilon = kwargs.pop("epsilon", self.epsilon.copy())

        # B is only kept when the partition itself is kept.
        if B is None:
            B = self.B if b is None else None
        if clabel is None:
            clabel = self.clabel.fa

        if self.overlap:
            half_edge_args = dict(base_g=self.base_g,
                                  half_edges=self.agg_state.half_edges,
                                  node_index=self.agg_state.node_index,
                                  eindex=self.agg_state.eindex)
        else:
            half_edge_args = dict(base_g=None, half_edges=None,
                                  node_index=None, eindex=None)

        state = LayeredBlockState(keep(g, self.g),
                                  ec=ec, gs=gs,
                                  eweight=keep(eweight, self.eweight),
                                  vweight=keep(vweight, self.vweight),
                                  recs=recs,
                                  drec=drec,
                                  rec_types=rec_types,
                                  rec_params=rec_params,
                                  b=keep(b, self.b),
                                  B=B,
                                  clabel=clabel,
                                  pclabel=keep(pclabel, self.pclabel),
                                  bfield=keep(bfield, self.bfield),
                                  deg_corr=keep(deg_corr, self.deg_corr),
                                  overlap=keep(overlap, self.overlap),
                                  layers=keep(layers, self.layers),
                                  ec_done=ec is not None,
                                  degs=degs, lweights=lweights,
                                  Lrecdx=Lrecdx,
                                  epsilon=epsilon,
                                  **half_edge_args,
                                  **kwargs)

        if self._coupled_state is not None:
            coupled = state.get_block_state(b=state.get_bclabel(),
                                            copy_bg=False,
                                            vweight="nonempty",
                                            Lrecdx=state.Lrecdx)
            state._couple_state(coupled, self._coupled_state[1])
        return state

    def __repr__(self):
        r"""Return a one-line summary of the state (blocks, layers,
        degree correction, covariates, graph and object id)."""
        overlap_tag = "overlapping " if self.overlap else ""
        layer_kind = "layers" if self.layers else "edge covariates"
        dc_tag = " degree-corrected," if self.deg_corr else ""
        n_rec = len(self.rec_types)
        if n_rec > 0:
            plural = "s" if n_rec > 1 else ""
            rec_tag = " with %d edge covariate%s," % (n_rec, plural)
        else:
            rec_tag = ""
        graph = self.base_g if self.overlap else self.g
        return ("<LayeredBlockState object with %d %sblocks, %d %s,%s%s "
                "for graph %s, at 0x%x>" %
                (self.B, overlap_tag, self.C, layer_kind, dc_tag, rec_tag,
                 str(graph), id(self)))

    def get_bg(self):
        r"""Returns the block graph.

        The block graph of every layer is built separately and merged into a
        single graph with ``graph_union``, which identifies block vertices by
        index.

        Returns ``(bg, mrs, ec, rec, drec)``. These are the merged block
        graph, its edge counts, the layer label of every block edge, and the
        ungrouped lists of block-level edge covariates.
        """
        bg = Graph(directed=self.g.is_directed())
        mrs = bg.new_ep("int")
        ec = bg.new_ep("int")
        rec = bg.new_edge_property("vector<double>")
        drec = bg.new_edge_property("vector<double>")

        for layer in range(self.C):
            # Block graph of the edges belonging to this layer only.
            view = GraphView(self.g, efilt=self.ec.a == layer)
            lbg = get_block_graph(view, self.B, self.b, self.vweight,
                                  self.eweight, rec=self.rec, drec=self.drec)
            layer_ec = lbg.new_ep("int")
            layer_ec.a = layer

            # Covariates are merged as vector-valued properties.
            if len(lbg.gp.rec) > 0:
                layer_rec = group_vector_property(lbg.gp.rec)
                layer_drec = group_vector_property(lbg.gp.drec)
            else:
                layer_rec = lbg.new_ep("vector<double>")
                layer_drec = lbg.new_ep("vector<double>")

            merged_props = [(mrs, lbg.ep["count"]),
                            (ec, layer_ec),
                            (rec, layer_rec),
                            (drec, layer_drec)]
            bg, merged = graph_union(bg, lbg,
                                     props=merged_props,
                                     intersection=lbg.vertex_index,
                                     include=True)
            mrs, ec, rec, drec = merged[:4]

        rec = ungroup_vector_property(rec, range(len(self.rec)))
        drec = ungroup_vector_property(drec, range(len(self.drec)))
        return bg, mrs, ec, rec, drec

    def get_block_state(self, b=None, vweight=False, deg_corr=False,
                        overlap=False, layers=None, **kwargs):
        r"""Returns a :class:`~graph_tool.inference.layered_blockmodel.LayeredBlockState`
        corresponding to the block graph. The parameters have the same meaning
        as in the constructor.

        Parameters
        ----------
        b : vertex labels of the block graph (optional, default: ``None``)
            If not given, every block is put in its own group
            (``bg.vertex_index``).
        vweight : ``bool`` or ``str`` (optional, default: ``False``)
            - ``True``: block weights, degrees and edge covariates are kept,
              and any coupled state is reproduced on the result.
            - ``"nonempty"``: weight one for every nonempty block.
            - ``"unity"``: weight one for every block.
            - anything else: no vertex weights.

            Except for ``True``, ``layers`` defaults to ``True``.
        deg_corr, overlap : ``bool``
            Passed to the new state.
        layers : ``bool`` (optional, default: ``None``)
            If ``None``, chosen from ``vweight`` as described above, falling
            back to ``self.layers``.
        copy_bg : ``bool`` (keyword, default: ``True``)
            If ``True``, a fresh block graph is built with :meth:`get_bg`.
            Otherwise the existing block graphs of the layer states are reused
            (as graph views) and no layer map is passed (``ec=None``).

        Other keyword arguments are forwarded to the constructor.
        """
        copy_bg = kwargs.pop("copy_bg", True)

        if copy_bg:
            bg, mrs, ec, brec, bdrec = self.get_bg()
            gs = []
        else:
            # Reuse the block graph of each layer state as a layer graph.
            gs = []
            for s in self.layer_states:
                u = GraphView(s.bg)
                u.ep.weight = s.mrs
                u.vp.vmap = u.own_property(s.g.vp.brmap).copy()
                u.vp.b = u.new_vp("int")
                if vweight == True:
                    u.vp.weight = u.own_property(s.wr)
                else:
                    u.vp.weight = u.new_vp("int", s.wr.a > 0)
                u.vp.brmap = u.new_vp("int")
                u.gp.rec = u.new_gp("object", val=s.brec)
                u.gp.drec = u.new_gp("object", val=s.bdrec)
                gs.append(u)
            bg = self.agg_state.bg
            mrs = self.agg_state.mrs
            ec = None
            brec = self.brec
            bdrec = self.bdrec

        lweights = bg.new_vp("vector<int>")
        # vweight may also be a string, so it is compared by equality.
        if not overlap and vweight == True:
            degs = libinference.get_layered_block_degs(self.g._Graph__graph,
                                                       _prop("e", self.g,
                                                             self.eweight),
                                                       _prop("v", self.g,
                                                             self.vweight),
                                                       _prop("e", self.g,
                                                             self.ec),
                                                       _prop("v", self.g,
                                                             self.b))
            libinference.get_blweights(self.g._Graph__graph,
                                       _prop("v", self.g, self.b),
                                       _prop("v", self.g, self.vc),
                                       _prop("v", self.g, self.vmap),
                                       _prop("v", bg, lweights),
                                       [_prop("v", state.g, state.vweight)
                                        for state in self.layer_states])
        else:
            degs = None

        copy_coupled = False
        keep_recs = False
        if vweight == "nonempty":
            vweight = bg.new_vp("int", self.wr.a > 0)
            layers = True if layers is None else layers
        elif vweight == "unity":
            vweight = bg.new_vp("int", 1)
            layers = True if layers is None else layers
        elif vweight == True:
            if copy_bg:
                vweight = bg.own_property(self.wr.copy())
            else:
                vweight = self.wr
            keep_recs = True
            copy_coupled = True
            kwargs["Lrecdx"] = kwargs.get("Lrecdx",
                                          [x.copy() for x in self.Lrecdx])
        else:
            vweight = None
            layers = True if layers is None else layers

        if keep_recs:
            rec_types = kwargs.pop("rec_types", self.rec_types)
            recs = kwargs.pop("recs", brec)
            drec = kwargs.pop("drec", bdrec)
            rec_params = kwargs.pop("rec_params", self.rec_params)
        else:
            # Rebuild the covariates at the block level with microcanonical
            # priors.
            recs = []
            drec = None
            for u in gs:
                u.gp.drec = None
                u.gp.rec = []
            rec_types = []
            rec_params = []
            for i, (rt, rp, r) in enumerate(zip(self.rec_types, self.wparams,
                                                brec)):
                if rt == libinference.rec_type.count:
                    recs.append(bg.new_ep("double", mrs.fa > 0))
                    for u in gs:
                        u.gp.rec.append(u.new_ep("double", u.ep.weight.fa > 0))
                    rec_types.append(rt)
                    rec_params.append("microcanonical")
                elif numpy.isnan(rp.a).sum() == 0:
                    # Covariates whose wparams contain no NaN are not carried
                    # to the block level.
                    # NOTE(review): presumably these have fixed
                    # hyperparameters — confirm.
                    continue
                elif rt in [libinference.rec_type.discrete_geometric,
                            libinference.rec_type.discrete_binomial,
                            libinference.rec_type.discrete_poisson]:
                    recs.append(r)
                    for l, u in enumerate(gs):
                        u.gp.rec.append(self.layer_states[l].brec[i])
                    # All discrete covariate types become geometric.
                    rec_types.append(libinference.rec_type.discrete_geometric)
                    rec_params.append("microcanonical")
                elif rt in [libinference.rec_type.real_exponential,
                            libinference.rec_type.real_normal]:
                    recs.append(r)
                    for l, u in enumerate(gs):
                        u.gp.rec.append(self.layer_states[l].brec[i])
                    rec_types.append(rt)
                    rec_params.append("microcanonical")
            rec_params = kwargs.pop("rec_params", rec_params)

        state = LayeredBlockState(bg, ec, eweight=mrs,
                                  vweight=vweight,
                                  gs=gs,
                                  rec_types=rec_types,
                                  recs=recs,
                                  drec=drec,
                                  rec_params=rec_params,
                                  b=bg.vertex_index.copy("int") if b is None else b,
                                  deg_corr=deg_corr,
                                  overlap=overlap,
                                  dense_bg=self.dense_bg,
                                  layers=self.layers if layers is None else layers,
                                  ec_done=True,
                                  degs=degs, lweights=lweights,
                                  clabel=kwargs.pop("clabel",
                                                    self.agg_state.get_bclabel()),
                                  pclabel=kwargs.pop("pclabel",
                                                     self.agg_state.get_bpclabel()),
                                  epsilon=kwargs.pop("epsilon",
                                                     self.epsilon.copy()),
                                  **kwargs)

        if copy_coupled and self._coupled_state is not None:
            state._couple_state(state.get_block_state(b=state.get_bclabel(),
                                                      copy_bg=False,
                                                      vweight="nonempty",
                                                      Lrecdx=state.Lrecdx),
                                self._coupled_state[1])

        return state

    # NOTE: an earlier ``_set_bclabel`` definition that only forwarded to
    # ``BlockState._set_bclabel`` (plus commented-out per-layer code) used to
    # stand here. It was dead code: this class body defines ``_set_bclabel``
    # again further down, and Python keeps only the last definition of a name
    # in a class body.

    def get_edge_blocks(self):
        r"""Return an edge property map holding the pair of block labels of
        every edge.

        The map is obtained from the aggregated state (``self.agg_state``).

        Raises
        ------
        ValueError
            If the state is not overlapping (``overlap == False``).
        """
        if self.overlap:
            return self.agg_state.get_edge_blocks()
        raise ValueError("edge blocks only available if overlap == True")

    def get_overlap_blocks(self):
        r"""Returns the mixed membership of each vertex.

        Returns
        -------
        bv : :class:`~graph_tool.VertexPropertyMap`
           A vector-valued vertex property map containing the block memberships
           of each node.
        bc_in : :class:`~graph_tool.VertexPropertyMap`
           The labelled in-degrees of each node, i.e. how many in-edges belong
           to each group, in the same order as the ``bv`` property above.
        bc_out : :class:`~graph_tool.VertexPropertyMap`
           The labelled out-degrees of each node, i.e. how many out-edges belong
           to each group, in the same order as the ``bv`` property above.
        bc_total : :class:`~graph_tool.VertexPropertyMap`
           The labelled total degrees of each node, i.e. how many incident edges
           belong to each group, in the same order as the ``bv`` property above.

        Raises
        ------
        ValueError
            If the state is not overlapping (``overlap == False``).

        Notes
        -----
        The result is delegated to the aggregated state (``self.agg_state``).
        """
        if self.overlap:
            return self.agg_state.get_overlap_blocks()
        raise ValueError("overlap blocks only available if overlap == True")

    def get_nonoverlap_blocks(self):
        r"""Return a scalar-valued vertex property map with the block mixture
        of each vertex represented as a single number.

        For non-overlapping states this is simply a copy of ``self.b``;
        otherwise the aggregated state (``self.agg_state``) computes it.
        """
        if self.overlap:
            return self.agg_state.get_nonoverlap_blocks()
        return self.b.copy()

    def get_majority_blocks(self):
        r"""Return a scalar-valued vertex property map with the majority block
        membership of each vertex.

        For non-overlapping states this is simply a copy of ``self.b``;
        otherwise the aggregated state (``self.agg_state``) computes it.
        """
        if self.overlap:
            return self.agg_state.get_majority_blocks()
        return self.b.copy()

732
    def _couple_state(self, state, entropy_args):
        """Couple this state to ``state`` in the compiled backend, or decouple
        it when ``state`` is ``None``.

        Parameters
        ----------
        state : LayeredBlockState or None
            State to couple to; ``None`` removes any existing coupling.
        entropy_args : dict
            Overrides merged on top of ``self._entropy_args`` and converted
            with ``get_entropy_args`` before being handed to the backend.
        """
        if state is None:
            self._coupled_state = None
            self._state.decouple_state()
            return

        if _bm_test():
            # Test-mode sanity checks: the coupled state's graphs must be the
            # block graphs of this state, both aggregated and per layer.
            assert state.g.base is self.bg.base
            assert state.agg_state.g.base is self.agg_state.bg.base
            layer_pairs = zip(state.layer_states, self.layer_states)
            for idx, (other, mine) in enumerate(layer_pairs):
                assert other.g.base is mine.bg.base, (idx, other, mine)

        self._coupled_state = (state, entropy_args)
        merged_args = dict(self._entropy_args, **entropy_args)
        self._state.couple_state(state._state, get_entropy_args(merged_args))

    def _set_bclabel(self, bstate):
        """Apply :meth:`BlockState._set_bclabel` with ``bstate`` to this state,
        then do the same for every layer, pairing ``self.layer_states`` with
        ``bstate.layer_states`` in order."""
        BlockState._set_bclabel(self, bstate)
        layer_pairs = zip(self.layer_states, bstate.layer_states)
        for layer_state, layer_bstate in layer_pairs:
            layer_state._set_bclabel(layer_bstate)

    def _check_clabel(self, clabel=None, b=None):
        """Return ``True`` if the base-class label check passes, ``False``
        otherwise.

        A stricter per-layer comparison against the coupled state's layer
        partitions used to exist here but is disabled, so only
        :meth:`BlockState._check_clabel` is consulted.
        """
        return bool(BlockState._check_clabel(self, clabel, b))

768
769
770
    def entropy(self, adjacency=True, dl=True, partition_dl=True,
                degree_dl=True, degree_dl_kind="distributed", edges_dl=True,
                dense=False, multigraph=True, deg_entropy=True, exact=True,
                **kwargs):
        r"""Calculate the entropy associated with the current block partition. The
        meaning of the parameters are the same as in
        :meth:`graph_tool.inference.blockmodel.BlockState.entropy`.

        When test mode is enabled (``_bm_test()``) and ``test`` is not passed
        as ``False``, the result is additionally checked to be finite and to
        agree (within ``1e-8``) with the entropy of a copy of this state.
        """
        S = BlockState.entropy(self, adjacency=adjacency, dl=dl,
                               partition_dl=partition_dl, degree_dl=degree_dl,
                               degree_dl_kind=degree_dl_kind, edges_dl=edges_dl,
                               dense=dense, multigraph=multigraph,
                               deg_entropy=deg_entropy, exact=exact,
                               **dict(kwargs, test=False))

        if _bm_test() and kwargs.get("test", True):
            import math

            # Collect the arguments explicitly rather than via ``locals()``,
            # which would also capture any other local names.
            args = dict(adjacency=adjacency, dl=dl, partition_dl=partition_dl,
                        degree_dl=degree_dl, degree_dl_kind=degree_dl_kind,
                        edges_dl=edges_dl, dense=dense, multigraph=multigraph,
                        deg_entropy=deg_entropy, exact=exact)
            args.update(kwargs)
            # ``test`` is passed explicitly to the re-check below; leaving it
            # in ``args`` would raise "multiple values for keyword 'test'".
            args.pop("test", None)

            assert not isnan(S) and not isinf(S), \
                "invalid entropy %g (%s) " % (S, str(args))

            state = self.copy()
            Salt = state.entropy(test=False, **args)
            assert math.isclose(S, Salt, abs_tol=1e-8), \
                "entropy discrepancy after copying (%g %g)" % (S, Salt)

        return S

801
802
803
804
805
806
807
    def _get_lvertex(self, v, l):
        """Return the vertex of layer ``l`` that corresponds to vertex ``v``.

        ``self.vc[v]`` lists the layers ``v`` belongs to and ``self.vmap[v]``
        the matching layer-local vertices at the same positions. The layer is
        located with :func:`numpy.searchsorted`, which assumes ``self.vc[v]``
        is sorted in increasing order.

        Raises
        ------
        ValueError
            If ``v`` does not belong to layer ``l``.
        """
        layers = self.vc[v]
        pos = numpy.searchsorted(layers.a, l)
        if pos < len(layers) and l == layers[pos]:
            return self.vmap[v][pos]
        raise ValueError("vertex %d not present in layer %d" % (v, l))

808
    def get_edges_prob(self, missing, spurious=None, entropy_args=None):
        r"""Compute the joint log-probability of the missing and spurious edges given by
        ``missing`` and ``spurious`` (a list of ``(source, target, layer)``
        tuples, or :meth:`~graph_tool.Edge` instances), together with the
        observed edges.

        More precisely, the log-likelihood returned is

        .. math::

            \ln \frac{P(\boldsymbol G + \delta \boldsymbol G | \boldsymbol b)}{P(\boldsymbol G| \boldsymbol b)}

        where :math:`\boldsymbol G + \delta \boldsymbol G` is the modified graph
        (with missing edges added and spurious edges deleted).

        The values in ``entropy_args`` are passed to
        :meth:`graph_tool.inference.blockmodel.BlockState.entropy()` to calculate the
        log-probability.

        Every edge is applied twice: to the aggregated graph
        (``self.agg_state``) and to the graph of its layer
        (``self.layer_states[layer]``). All modifications are undone before
        returning, also when an error is raised.

        Raises
        ------
        ValueError
            If an endpoint of an edge is not present in the edge's layer, or
            if a spurious edge is not found in the graph.
        """
        missing = list(missing)
        spurious = [] if spurious is None else list(spurious)
        entropy_args = {} if entropy_args is None else entropy_args
        eargs = dict(dict(partition_dl=False), **entropy_args)

        Si = self.entropy(**eargs)

        pos = {}  # vertex -> block label, for every endpoint involved

        def expand(e):
            # Parse one requested edge and return its two representations:
            # in the aggregated graph (flag False), and in its layer graph with
            # layer-local endpoints (flag True).
            try:
                u, v = e
                l = self.ec[e]
            except (TypeError, ValueError):
                u, v, l = e
            pos[u] = self.b[u]
            pos[v] = self.b[v]
            return [(u, v, (l, False)),
                    (self._get_lvertex(u, l), self._get_lvertex(v, l),
                     (l, True))]

        def edge_state(l):
            # The state whose graph holds this representation of the edge.
            return self.layer_states[l[0]] if l[1] else self.agg_state

        def find_edge(state, u, v, l):
            # Locate edge (u, v) in ``state``; in the aggregated graph it must
            # also carry layer l[0] when edge covariates are present.
            if l[1]:
                return state.g.edge(u, v)
            es = state.g.edge(u, v, all_edges=True)
            if self.ec is not None:
                es = [e for e in es if self.ec[e] == l[0]]
            else:
                es = list(es)
            return es[0] if len(es) > 0 else None

        # Each requested edge yields TWO entries, so the missing and spurious
        # parts must be kept apart instead of indexed by len(missing).
        missing_es = [x for e in missing for x in expand(e)]
        spurious_es = [x for e in spurious for x in expand(e)]
        edge_list = missing_es + spurious_es

        self.remove_vertex(pos.keys())

        # Initialised before ``try`` so the cleanup can always read them.
        new_es = []   # (edge, layer flag) added for missing edges
        old_es = []   # (u, v, layer flag) removed for spurious edges
        try:
            for u, v, l in missing_es:
                state = edge_state(l)
                e = state.g.add_edge(u, v)
                if not l[1] and self.ec is not None:
                    self.ec[e] = l[0]
                if state.is_weighted:
                    state.eweight[e] = 1
                new_es.append((e, l))

            for u, v, l in spurious_es:
                state = edge_state(l)
                e = find_edge(state, u, v, l)
                if e is None:
                    raise ValueError("edge not found: (%d, %d, %d)" %
                                     (int(u), int(v), l[0]))
                if state.is_weighted:
                    state.eweight[e] -= 1
                    if state.eweight[e] == 0:
                        state.g.remove_edge(e)
                else:
                    state.g.remove_edge(e)
                old_es.append((u, v, l))

            self.add_vertex(pos.keys(), pos.values())
            try:
                Sf = self.entropy(**eargs)
            finally:
                # Always detach again, so the outer cleanup never re-adds
                # vertices that are already present.
                self.remove_vertex(pos.keys())
        finally:
            for e, l in new_es:
                edge_state(l).g.remove_edge(e)
            for u, v, l in old_es:
                state = edge_state(l)
                if state.is_weighted:
                    e = find_edge(state, u, v, l)
                    if e is None:
                        e = state.g.add_edge(u, v)
                        state.eweight[e] = 0
                        if not l[1] and self.ec is not None:
                            self.ec[e] = l[0]
                    state.eweight[e] += 1
                else:
                    e = state.g.add_edge(u, v)
                    if not l[1] and self.ec is not None:
                        self.ec[e] = l[0]
            self.add_vertex(pos.keys(), pos.values())

        L = Si - Sf

        if _bm_test():
            import math

            state = self.copy()
            set_test(False)
            try:
                L_alt = state.get_edges_prob(missing, spurious=spurious,
                                             entropy_args=entropy_args)
            finally:
                set_test(True)
            assert math.isclose(L, L_alt, abs_tol=1e-8), \
                "inconsistent missing=%s edge probability (%g, %g): %s, %s" % \
                (str(missing), L, L_alt,  str(entropy_args), str(edge_list))

        return L

942
943
944
    def _clear_egroups(self):
        """Ask the compiled state (``self._state``) to clear its edge groups."""
        backend = self._state
        backend.clear_egroups()

945
946
947
948
949
    def _mcmc_sweep_dispatch(self, mcmc_state):
        """Run one Metropolis-Hastings sweep in the compiled layered model.

        For non-overlapping states the backend's result is returned as is.
        For overlapping states a ``(dS, nattempts, nmoves)`` triple is
        returned. If :meth:`mcmc_sweep` was last called with
        ``bundled=True``, an additional "bundled" sweep is run and its
        counters are added to those of the regular sweep.
        """
        if not self.overlap:
            return libinference.mcmc_layered_sweep(mcmc_state, self._state,
                                                   _get_rng())

        dS, nattempts, nmoves = \
            libinference.mcmc_layered_overlap_sweep(mcmc_state,
                                                    self._state,
                                                    _get_rng())

        # ``__bundled`` is only assigned by mcmc_sweep(); if this dispatcher is
        # reached some other way, fall back to the plain sweep instead of
        # failing with AttributeError.
        try:
            bundled = self.__bundled
        except AttributeError:
            bundled = False

        if bundled:
            ret = libinference.mcmc_layered_overlap_bundled_sweep(mcmc_state,
                                                                  self._state,
                                                                  _get_rng())
            dS += ret[0]
            nattempts += ret[1]
            nmoves += ret[2]
        return dS, nattempts, nmoves

963
964
965
966
967
968
969
970
971
    def _mcmc_sweep_parallel_dispatch(states, mcmc_states):
        """Run Metropolis-Hastings sweeps over several states in parallel.

        Note there is no ``self``: the first argument is the list of states.
        The overlapping or non-overlapping backend is chosen from
        ``states[0]``.
        """
        if states[0].overlap:
            sweep = libinference.mcmc_layered_overlap_sweep_parallel
        else:
            sweep = libinference.mcmc_layered_sweep_parallel
        backends = [s._state for s in states]
        return sweep(mcmc_states, backends, _get_rng())
972
973
974
975
976
    def mcmc_sweep(self, bundled=False, **kwargs):
        r"""Perform sweeps of a Metropolis-Hastings rejection sampling MCMC to
        sample network partitions.

        If ``bundled == True`` and the state is an overlapping one, the
        half-edges incident on the same node that belong to the same group are
        moved together. All remaining parameters are passed to
        :meth:`graph_tool.inference.blockmodel.BlockState.mcmc_sweep`.
        """
        # Remembered on the instance for the dispatch step, which presumably
        # is invoked by BlockState.mcmc_sweep (reads ``self.__bundled``).
        self.__bundled = bundled
        base_sweep = BlockState.mcmc_sweep
        return base_sweep(self, **kwargs)

982
983
984
985
986
987
988
989
990
991
    def _multiflip_mcmc_sweep_dispatch(self, mcmc_state):
        """Run one multiflip MCMC sweep with the layered backend matching
        ``self.overlap``, and return its result."""
        sweep = (libinference.multiflip_mcmc_layered_overlap_sweep
                 if self.overlap
                 else libinference.multiflip_mcmc_layered_sweep)
        return sweep(mcmc_state, self._state, _get_rng())

992
993
994
995
996
997
998
999
1000
1001
    def _multiflip_mcmc_sweep_parallel_dispatch(states, mcmc_states):
        """Run multiflip MCMC sweeps over several states in parallel.

        Note there is no ``self``: the first argument is the list of states.
        The backend variant is chosen from ``states[0].overlap``.
        """
        if states[0].overlap:
            sweep = libinference.multiflip_mcmc_layered_overlap_sweep_parallel
        else:
            sweep = libinference.multiflip_mcmc_layered_sweep_parallel
        backends = [s._state for s in states]
        return sweep(mcmc_states, backends, _get_rng())

1002
1003
1004
1005
1006
1007
1008
1009
1010
    def _gibbs_sweep_dispatch(self, mcmc_state):
        """Run one Gibbs sweep with the layered backend matching
        ``self.overlap``, and return its result."""
        sweep = (libinference.gibbs_layered_overlap_sweep
                 if self.overlap
                 else libinference.gibbs_layered_sweep)
        return sweep(mcmc_state, self._state, _get_rng())

1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
    def _gibbs_sweep_parallel_dispatch(states, gibbs_states):
        """Run Gibbs sweeps over several states in parallel.

        Note there is no ``self``: the first argument is the list of states.
        The backend variant is chosen from ``states[0].overlap``.
        """
        if states[0].overlap:
            sweep = libinference.gibbs_layered_overlap_sweep_parallel
        else:
            sweep = libinference.gibbs_layered_sweep_parallel
        backends = [s._state for s in states]
        return sweep(gibbs_states, backends, _get_rng())

1021
1022
    def _multicanonical_sweep_dispatch(self, mcmc_state):
        """Run one multicanonical sweep and return the backend's result.

        The backend routine is selected by two flags: ``self.overlap``
        (overlapping vs. non-overlapping layered state) and
        ``mcmc_state.multiflip`` (multiflip vs. single-node moves).
        """
        if self.overlap:
            if mcmc_state.multiflip:
                sweep = libinference.multicanonical_layered_overlap_multiflip_sweep
            else:
                sweep = libinference.multicanonical_layered_overlap_sweep
        elif mcmc_state.multiflip:
            sweep = libinference.multicanonical_layered_multiflip_sweep
        else:
            sweep = libinference.multicanonical_layered_sweep
        return sweep(mcmc_state, self._state, _get_rng())

1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
    def _exhaustive_sweep_dispatch(self, exhaustive_state, callback, hist):
        """Run an exhaustive sweep over partitions with the layered backend.

        The mode is chosen by the arguments:

        * ``callback`` given: the callback variant is used.
        * otherwise ``hist is None``: the iterator variant is used.
        * otherwise: the density variant is used, receiving ``hist[0]``,
          ``hist[1]`` and ``hist[2]``.

        The overlapping or non-overlapping backend is chosen by
        ``self.overlap``.
        """
        overlap = self.overlap
        if callback is not None:
            sweep = (libinference.exhaustive_layered_overlap_sweep if overlap
                     else libinference.exhaustive_layered_sweep)
            return sweep(exhaustive_state, self._state, callback)
        if hist is None:
            sweep_iter = (libinference.exhaustive_layered_overlap_sweep_iter
                          if overlap
                          else libinference.exhaustive_layered_sweep_iter)
            return sweep_iter(exhaustive_state, self._state)
        # NOTE(review): the density routines are named inconsistently
        # (``exhaustive_layered_sweep_dens`` vs.
        # ``exhaustive_layered_overlap_dens``); one of them may not match the
        # name exported by libgraph_tool_inference -- confirm.
        dens = (libinference.exhaustive_layered_overlap_dens if overlap
                else libinference.exhaustive_layered_sweep_dens)
        return dens(exhaustive_state, self._state, hist[0], hist[1], hist[2])
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
    def _merge_sweep_dispatch(self, merge_state):
        """Run one merge sweep and return the backend's result.

        Non-overlapping states use the merge routine; overlapping states use
        the "vacate" routine instead.
        """
        sweep = (libinference.vacate_layered_overlap_sweep
                 if self.overlap
                 else libinference.merge_layered_sweep)
        return sweep(merge_state, self._state, _get_rng())

    def shrink(self, B, **kwargs):
        """Reduces the order of current state by progressively merging groups, until
        only ``B`` are left. All remaining keyword arguments are passed to
        :meth:`graph_tool.inference.blockmodel.BlockState.shrink` or
        :meth:`graph_tool.inference.overlap_blockmodel.OverlapBlockState.shrink`, as appropriate.

        This function leaves the current state untouched and returns instead a
        copy with the new partition.
        """
        implementation = OverlapBlockState if self.overlap else BlockState
        return implementation.shrink(self, B, **kwargs)

    def draw(self, **kwargs):
        """Convenience function to draw the current state. All keyword arguments are
        passed to :meth:`graph_tool.inference.blockmodel.BlockState.draw` or
        :meth:`graph_tool.inference.overlap_blockmodel.OverlapBlockState.draw`, as appropriate.

        Drawing is delegated to the aggregated state (``self.agg_state``).
        Whatever its ``draw`` returns is passed back to the caller; previously
        it was silently discarded.
        """
        return self.agg_state.draw(**kwargs)


def init_layer_confined(g, ec):
    """Build an initial overlapping edge partition that does not mix layers.

    A :class:`CovariateBlockState` with ``B = g.num_vertices()`` is built and
    copied as an overlapping state. The first component of its edge-block
    pairs is then offset by ``ec`` times (maximum label + 1), so that edges
    in different layers never share a label. ``continuous_map`` is applied
    to that result, presumably to relabel it to a contiguous range.

    Returns an edge vector property map whose two entries are both that
    combined label.
    """
    overlap_state = CovariateBlockState(g, ec=ec,
                                        B=g.num_vertices()).copy(overlap=True)
    edge_blocks = overlap_state.get_edge_blocks()
    labels = ungroup_vector_property(edge_blocks, [0])[0]
    stride = labels.fa.max() + 1
    labels.fa = labels.fa + ec.fa * stride
    continuous_map(labels)
    return group_vector_property([labels, labels])