Commit d4eae9c5 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

Fix edge covariate issues with layered SBMs

parent bf5e1f9f
......@@ -98,7 +98,7 @@ void split_layers(GraphInterface& gi, boost::any& aec, boost::any& ab,
typedef vprop_map_t<int32_t>::type vmap_t;
typedef vprop_map_t<vector<int32_t>>::type vvmap_t;
typedef eprop_map_t<int32_t>::type emap_t;
typedef eprop_map_t<double>::type remap_t;
typedef eprop_map_t<std::vector<double>>::type remap_t;
emap_t& ec = any_cast<emap_t&>(aec);
vmap_t& b = any_cast<vmap_t&>(ab);
......
......@@ -421,8 +421,7 @@ class BlockState(object):
self.brecsum = self.bg.new_vp("double")
self.bignore_degrees = self.bg.new_vp("bool")
self.rec_params = rec_params
rec_params = list(rec_params)
self.rec_params = rec_params = list(rec_params)
while len(rec_params) < len(self.rec_types):
rec_params.append({})
self.wparams = libcore.Vector_Vector_double()
......@@ -452,6 +451,7 @@ class BlockState(object):
ks = list(defaults.keys())
defaults.update(rec_params[i])
rec_params[i] = dict(**defaults)
for k in ks:
ps.append(defaults.pop(k))
if len(defaults) > 0:
......@@ -500,6 +500,8 @@ class BlockState(object):
r"""Copies the block state. The parameters override the state properties, and
have the same meaning as in the constructor."""
recs = ungroup_vector_property(self.rec, range(len(self.recs)))
if not overlap:
state = BlockState(self.g if g is None else g,
eweight=self.eweight if eweight is None else eweight,
......@@ -513,7 +515,7 @@ class BlockState(object):
degs=self.degs.copy(),
merge_map=kwargs.pop("merge_map",
self.merge_map.copy()),
recs=kwargs.pop("recs", self.recs),
recs=kwargs.pop("recs", recs),
drec=kwargs.pop("drec", self.drec),
rec_types=kwargs.pop("rec_types", self.rec_types),
rec_params=kwargs.pop("rec_params",
......@@ -527,9 +529,10 @@ class BlockState(object):
state = OverlapBlockState(self.g if g is None else g,
b=self.b.copy() if b is None else b,
B=(self.B if b is None else None) if B is None else B,
recs=kwargs.pop("recs", self.recs),
recs=kwargs.pop("recs", recs),
drec=kwargs.pop("drec", self.drec),
rec_types=kwargs.pop("rec_types", self.rec_types),
rec_types=kwargs.pop("rec_types",
self.rec_types),
rec_params=kwargs.pop("rec_params",
self.rec_params),
clabel=self.clabel if clabel is None else clabel,
......@@ -553,7 +556,8 @@ class BlockState(object):
deg_corr=self.deg_corr,
allow_empty=self.allow_empty,
max_BE=self.max_BE,
recs=self.recs,
recs=ungroup_vector_property(self.rec,
range(len(self.recs))),
drec=self.drec,
rec_types=list(self.rec_types),
rec_params=self.rec_params,
......
......@@ -59,15 +59,16 @@ class LayeredBlockState(OverlapBlockState, BlockState):
ec : :class:`~graph_tool.PropertyMap` Edge :class:`~graph_tool.PropertyMap`
containing discrete edge covariates that will split the network in
discrete layers.
rec : :class:`~graph_tool.PropertyMap` (optional, default: ``None``)
Real-valued edge covariates.
rec_type : `"positive"`, `"signed"` or `None` (optional, default: ``None``)
Type of edge covariates. If not specified, it will be guessed from
``rec``.
rec_params : ``dict`` (optional, default: ``{}``)
Model hyperparameters for real-valued covariates. This should be a
``dict`` with keys in the list ``["alpha", "beta"]`` if ``rec_type ==
positive`` or ``["m0", "k0", "v0". "nu0"]`` if ``rec_type == signed``.
recs : list of :class:`~graph_tool.PropertyMap` instances (optional, default: ``[]``)
List of real or discrete-valued edge covariates.
rec_types : list of edge covariate types (optional, default: ``[]``)
List of types of edge covariates. The possible types are:
``"real-exponential"``, ``"real-normal"``, ``"discrete-geometric"``,
``"discrete-poisson"`` or ``"discrete-binomial"``.
rec_params : list of ``dict`` (optional, default: ``[]``)
Model hyperparameters for edge covariates. This should a list of
``dict`` instances. See :class:`~graph_tool.inference.BlockState` for
more details.
eweight : :class:`~graph_tool.PropertyMap` (optional, default: ``None``)
Edge multiplicities (for multigraphs or block graphs).
vweight : :class:`~graph_tool.PropertyMap` (optional, default: ``None``)
......@@ -101,10 +102,10 @@ class LayeredBlockState(OverlapBlockState, BlockState):
the block graph. Otherwise a dense matrix will be used.
"""
def __init__(self, g, ec, eweight=None, vweight=None, rec=None,
rec_type=None, rec_params={}, b=None, B=None, clabel=None,
pclabel=False, layers=False, deg_corr=True, overlap=False,
allow_empty=False, **kwargs):
def __init__(self, g, ec, eweight=None, vweight=None, recs=[], rec_types=[],
rec_params=[], b=None, B=None, clabel=None, pclabel=False,
layers=False, deg_corr=True, overlap=False, allow_empty=False,
**kwargs):
kwargs = kwargs.copy()
......@@ -141,8 +142,8 @@ class LayeredBlockState(OverlapBlockState, BlockState):
tdegs = libinference.simple_degs_t()
agg_state = BlockState(GraphView(g, skip_properties=True), b=b, B=B,
eweight=eweight, vweight=vweight, rec=rec,
rec_type=rec_type, rec_params=rec_params,
eweight=eweight, vweight=vweight, recs=recs,
rec_types=rec_types, rec_params=rec_params,
clabel=clabel, pclabel=pclabel,
deg_corr=deg_corr, allow_empty=allow_empty,
max_BE=max_BE, degs=tdegs,
......@@ -151,8 +152,8 @@ class LayeredBlockState(OverlapBlockState, BlockState):
else:
kwargs = dmask(kwargs, ["degs"])
ldegs = None
agg_state = OverlapBlockState(g, b=b, B=B, rec=rec,
rec_type=rec_type,
agg_state = OverlapBlockState(g, b=b, B=B, recs=recs,
rec_types=rec_types,
rec_params=rec_params, clabel=clabel,
pclabel=pclabel, deg_corr=deg_corr,
allow_empty=allow_empty, max_BE=max_BE,
......@@ -182,9 +183,10 @@ class LayeredBlockState(OverlapBlockState, BlockState):
self.allow_empty = agg_state.allow_empty
self.recs = agg_state.recs
self.rec = agg_state.rec
self.drec = agg_state.drec
self.rec_type = agg_state.rec_type
self.rec_types = agg_state.rec_types
self.rec_params = agg_state.rec_params
self.b = agg_state.b
......@@ -209,8 +211,8 @@ class LayeredBlockState(OverlapBlockState, BlockState):
u.vp["b"] = u.new_vp("int")
u.vp["weight"] = u.new_vp("int")
u.ep["weight"] = u.new_ep("int")
u.ep["rec"] = u.new_ep("double")
u.ep["drec"] = u.new_ep("double")
u.ep["rec"] = u.new_ep("vector<double>")
u.ep["drec"] = u.new_ep("vector<double>")
u.vp["brmap"] = u.new_vp("int")
u.vp["vmap"] = u.new_vp("int")
self.gs.append(u)
......@@ -279,7 +281,7 @@ class LayeredBlockState(OverlapBlockState, BlockState):
if _bm_test():
assert self.mrs.fa.sum() == self.eweight.fa.sum(), "inconsistent mrs!"
kwargs.pop("rec", None)
kwargs.pop("recs", None)
kwargs.pop("drec", None)
kwargs.pop("rec_params", None)
......@@ -326,11 +328,12 @@ class LayeredBlockState(OverlapBlockState, BlockState):
u.vp.vmap))
else:
degs = libinference.simple_degs_t()
recs = ungroup_vector_property(u.ep["rec"], range(len(self.recs)))
state = BlockState(u, b=u.vp["b"],
B=B,
rec=u.ep["rec"],
recs=recs,
drec=u.ep["drec"],
rec_type=self.rec_type,
rec_types=self.rec_types,
rec_params=self.rec_params,
eweight=u.ep["weight"],
vweight=u.vp["weight"],
......@@ -341,9 +344,9 @@ class LayeredBlockState(OverlapBlockState, BlockState):
base_u, node_index = self.__get_base_u(u)
state = OverlapBlockState(u, b=u.vp["b"].fa,
B=B,
rec=u.ep["rec"],
recs=recs,
drec=u.ep["drec"],
rec_type=self.rec_type,
rec_types=self.rec_types,
rec_params=self.rec_params,
node_index=node_index,
base_g=base_u,
......@@ -357,9 +360,10 @@ class LayeredBlockState(OverlapBlockState, BlockState):
def __getstate__(self):
state = dict(g=self.g,
ec=self.ec,
rec=self.rec if self.rec_type != libinference.rec_type.none else None,
drec=self.drec if self.rec_type == libinference.rec_type.signed else None,
rec_type=int(self.rec_type),
recs=ungroup_vector_property(self.rec,
range(len(self.recs))),
drec=self.drec,
rec_types=list(self.rec_types),
rec_params=self.rec_params,
layers=self.layers,
eweight=self.eweight,
......@@ -407,13 +411,15 @@ class LayeredBlockState(OverlapBlockState, BlockState):
else:
degs = libinference.simple_degs_t()
recs = ungroup_vector_property(self.rec, range(len(self.recs)))
state = LayeredBlockState(self.g if g is None else g,
ec=self.ec if ec is None else ec,
eweight=self.eweight if eweight is None else eweight,
vweight=self.vweight if vweight is None else vweight,
rec=kwargs.get("rec", self.rec),
recs=kwargs.get("recs", recs),
drec=kwargs.get("drec", self.drec),
rec_type=kwargs.get("rec_type", self.rec_type),
rec_types=kwargs.get("rec_types", self.rec_types),
rec_params=kwargs.get("rec_params", self.rec_params),
b=self.b if b is None else b,
B=(self.B if b is None else None) if B is None else B,
......@@ -431,7 +437,7 @@ class LayeredBlockState(OverlapBlockState, BlockState):
ec_done=ec is None,
degs=degs, lweights=lweights,
layer_entropy=self.__get_layer_entropy(),
**dmask(kwargs, ["rec", "rec_type", "drec",
**dmask(kwargs, ["recs", "rec_types", "drec",
"rec_params", "allow_empty"]))
return state
......@@ -440,10 +446,9 @@ class LayeredBlockState(OverlapBlockState, BlockState):
(self.B, "overlapping " if self.overlap else "",
self.C, "layers" if self.layers else "edge covariates",
" degree-corrected," if self.deg_corr else "",
((" with %s real-typed edge covariates," %
("positive" if self.rec_type == libinference.rec_type.positive
else "signed"))
if self.rec_type != libinference.rec_type.none else ""),
((" with %d edge covariate%s," % (len(self.rec_types),
"s" if len(self.rec_types) > 1 else ""))
if len(self.rec_types) > 0 else ""),
str(self.base_g if self.overlap else self.g), id(self))
def get_bg(self):
......@@ -452,8 +457,8 @@ class LayeredBlockState(OverlapBlockState, BlockState):
bg = Graph(directed=self.g.is_directed())
mrs = bg.new_edge_property("int")
ec = bg.new_edge_property("int")
rec = bg.new_edge_property("double")
drec = bg.new_edge_property("double")
rec = bg.new_edge_property("vector<double>")
drec = bg.new_edge_property("vector<double>")
for l in range(self.C):
u = GraphView(self.g, efilt=self.ec.a == l)
......@@ -484,6 +489,7 @@ class LayeredBlockState(OverlapBlockState, BlockState):
constructor."""
bg, mrs, ec, rec, drec = self.get_bg()
recs = ungroup_vector_property(rec, range(len(self.recs)))
lweights = bg.new_vp("vector<int>")
if not overlap and deg_corr and vweight:
degs = libinference.get_layered_block_degs(self.g._Graph__graph,
......@@ -509,11 +515,10 @@ class LayeredBlockState(OverlapBlockState, BlockState):
layer_entropy = self.__get_layer_entropy()
else:
layer_entropy = None
state = LayeredBlockState(bg, ec, eweight=mrs,
vweight=bg.own_property(self.wr.copy()) if vweight else None,
rec_type=kwargs.pop("rec_type", self.rec_type if vweight else None),
rec=kwargs.pop("rec", rec if vweight else None),
rec_types=kwargs.pop("rec_types", self.rec_types if vweight else None),
recs=kwargs.pop("recs", recs if vweight else None),
drec=kwargs.pop("drec", drec if vweight else None),
rec_params=kwargs.pop("rec_params", self.rec_params),
b=bg.vertex_index.copy("int") if b is None else b,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment