diff --git a/src/graph_tool/inference/blockmodel.py b/src/graph_tool/inference/blockmodel.py index ec901bf3872e6293d4c14559109d2f7813c31f81..930bba79132b0ff51c9910557b32553de0181e0e 100644 --- a/src/graph_tool/inference/blockmodel.py +++ b/src/graph_tool/inference/blockmodel.py @@ -2266,7 +2266,7 @@ class BlockState(object): "edge_gradient"])) def sample_graph(self, canonical=False, multigraph=True, self_loops=True, - max_ent=False, n_iter=1000): + sample_params=False, max_ent=False, n_iter=1000): r"""Sample a new graph from the fitted model. Parameters @@ -2279,6 +2279,11 @@ class BlockState(object): If ``True``, parallel edges will be allowed. self-loops : ``bool`` (optional, default: ``True``) If ``True``, self-loops will be allowed. + sample_params : ``bool`` (optional, default: ``True``) + If ``True``, and ``canonical == False`` and ``max_ent == False``, + the count parameters (edges between groups and node degrees) will be + sampled from their posterior distribution conditioned on the actual + state. Otherwise, their maximum-likelihood values will be used. max_ent : ``bool`` (optional, default: ``False``) If ``True``, maximum-entropy model variants will be used. n_iter : ``int`` (optional, default: ``1000``) @@ -2332,11 +2337,31 @@ class BlockState(object): in_degs = None probs = adjacency(self.bg, weight=self.mrs).T if not max_ent: + if canonical and sample_params: + rs = self.wr.a > 0 + B = rs.sum() + if self.g.is_directed(): + p = self.g.num_edges() / B ** 2 + if not self.g.is_directed(): + p = 2 * self.g.num_edges() / ((B + 1) * B) + idx = probs.nonzero() + probs[idx] = numpy.random.gamma(probs[idx] + 1, p/(p + 1)) + for r in rs: + idx = self.b.fa == r + er = probs[r,:].sum() + out_degs[idx] = numpy.random.dirichlet(out_degs[idx] + 1) + out_degs[idx] = numpy.random.multinomial(int(er), out_degs[idx]) + if in_degs is not None: + er = probs[:,r].sum() + in_degs[idx] = numpy.random.dirichlet(in_degs[idx] + 1) + in_degs[idx] = numpy.random.multinomial(int(er), in_degs[idx]) + g = generate_sbm(b=self.b.fa, probs=probs, in_degs=in_degs, out_degs=out_degs, directed=self.g.is_directed(), micro_ers=not canonical, micro_degs=not canonical and self.deg_corr) + if not multigraph: remove_parallel_edges(g) if not self_loops: