Commit 8037d8d5 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: implement stricter checking for unexpected keyword arguments

parent 8eef438b
......@@ -140,13 +140,13 @@ for directed in [True, False]:
print("\ndirected:", directed, "overlap:", overlap,
"layered:", layered, "deg-corr:", deg_corr, file=out)
state_args = dict(ec=ec, layers=(layered == True)) if layered != False else {}
state = minimize_blockmodel_dl(GraphView(g, directed=directed),
verbose=(1, "\t") if verbose else False,
deg_corr=deg_corr,
overlap=overlap,
layers=layered != False,
state_args=dict(ec=ec,
layers=(layered == True)))
state_args=state_args)
print(state.B, state.entropy(), file=out)
state = minimize_nested_blockmodel_dl(GraphView(g, directed=directed),
......@@ -154,8 +154,7 @@ for directed in [True, False]:
deg_corr=deg_corr,
overlap=overlap,
layers=layered != False,
state_args=dict(ec=ec,
layers=(layered == True)))
state_args=state_args)
if verbose:
state.print_summary()
print(state.entropy(), file=out)
......
......@@ -100,6 +100,7 @@ class BlockState(object):
def __init__(self, g, eweight=None, vweight=None, b=None, B=None,
clabel=None, pclabel=None, deg_corr=True, max_BE=1000,
**kwargs):
kwargs = kwargs.copy()
# initialize weights to unity, if necessary
if eweight is None or isinstance(eweight, libinference.unity_eprop_t):
......@@ -125,7 +126,7 @@ class BlockState(object):
self.deg_corr = deg_corr
self.overlap = False
self.degs = kwargs.get("degs", libinference.simple_degs_t())
self.degs = extract_arg(kwargs, "degs", libinference.simple_degs_t())
if self.degs is None:
self.degs = libinference.simple_degs_t()
......@@ -210,11 +211,12 @@ class BlockState(object):
else:
self.use_hash = libinference.false_type()
self.ignore_degrees = kwargs.get("ignore_degrees", None)
self.ignore_degrees = extract_arg(kwargs, "ignore_degrees", None)
if self.ignore_degrees is None:
self.ignore_degrees = g.new_vp("bool", False)
self.merge_map = kwargs.get("merge_map", self.g.vertex_index.copy("int"))
self.merge_map = extract_arg(kwargs, "merge_map",
self.g.vertex_index.copy("int"))
self.block_list = Vector_size_t()
self.block_list.extend(arange(self.B, dtype="int"))
......@@ -222,6 +224,10 @@ class BlockState(object):
self._abg = self.bg._get_any()
self._state = libinference.make_block_state(self, _get_rng())
if len(kwargs) > 0:
raise ValueError("unrecognized keyword arguments: " +
str(list(kwargs.keys())))
def __repr__(self):
return "<BlockState object with %d blocks,%s for graph %s, at 0x%x>" % \
......@@ -498,8 +504,10 @@ class BlockState(object):
del args["self"]
del args["kwargs"]
xi_fast = kwargs.get("xi_fast", False)
dl_deg_alt = kwargs.get("dl_deg_alt", True)
kwargs = kwargs.copy()
xi_fast = extract_arg(kwargs, "xi_fast", False)
dl_deg_alt = extract_arg(kwargs, "dl_deg_alt", True)
E = self.E
N = self.N
......@@ -525,11 +533,11 @@ class BlockState(object):
S += S_seq
callback = kwargs.get("callback", None)
callback = extract_arg(kwargs, "callback", None)
if callback is not None:
S += callback(self)
if _bm_test() and kwargs.get("test", True):
if _bm_test() and extract_arg(kwargs, "test", True):
assert not isnan(S) and not isinf(S), \
"invalid entropy %g (%s) " % (S, str(args))
......@@ -538,6 +546,9 @@ class BlockState(object):
assert abs(S - Salt) < 1e-6, \
"entropy discrepancy after copying (%g %g)" % (S, Salt)
if len(kwargs) > 0:
raise ValueError("unrecognized keyword arguments: " +
str(list(kwargs.keys())))
return S
def get_matrix(self):
......
......@@ -91,9 +91,12 @@ class LayeredBlockState(OverlapBlockState, BlockState):
def __init__(self, g, ec, eweight=None, vweight=None, b=None, B=None,
clabel=None, pclabel=False, layers=False, deg_corr=True,
overlap=False, **kwargs):
kwargs = kwargs.copy()
self.g = g
if kwargs.get("ec_done", False):
if extract_arg(kwargs, "ec_done", False):
self.ec = ec
else:
self.ec = ec = perfect_prop_hash([ec], "int32_t")[0]
......@@ -112,7 +115,9 @@ class LayeredBlockState(OverlapBlockState, BlockState):
eweight = g.new_ep("int", 1)
if not overlap:
ldegs = kwargs.get("degs", libinference.simple_degs_t())
kwargs = dmask(kwargs, ["base_g", "node_index", "eindex",
"half_edges"])
ldegs = extract_arg(kwargs, "degs", libinference.simple_degs_t())
if not isinstance(ldegs, libinference.simple_degs_t):
tdegs = libinference.get_mapped_block_degs(self.g._Graph__graph,
ldegs, 0,
......@@ -125,15 +130,21 @@ class LayeredBlockState(OverlapBlockState, BlockState):
B=B, eweight=eweight, vweight=vweight,
clabel=clabel, pclabel=pclabel,
deg_corr=deg_corr, max_BE=max_BE,
degs=tdegs, **dmask(kwargs, ["degs"]))
degs=tdegs,
**dmask(kwargs, ["degs", "lweights",
"layer_entropy"]))
else:
kwargs = dmask(kwargs, ["degs"])
ldegs = None
total_state = OverlapBlockState(g, b=b, B=B, eweight=eweight,
vweight=vweight, clabel=clabel,
total_state = OverlapBlockState(g, b=b, B=B, clabel=clabel,
pclabel=pclabel, deg_corr=deg_corr,
max_BE=max_BE, **kwargs)
max_BE=max_BE,
**dmask(kwargs, ["degs", "lweights",
"layer_entropy"]))
self.base_g = total_state.base_g
self.g = total_state.g
kwargs = dmask(kwargs, ["base_g", "node_index", "eindex",
"half_edges"])
self.total_state = total_state
......@@ -173,7 +184,7 @@ class LayeredBlockState(OverlapBlockState, BlockState):
self.gs = []
self.block_map = libinference.bmap_t()
lweights = kwargs.get("lweights", g.new_vp("vector<int>"))
lweights = extract_arg(kwargs, "lweights", g.new_vp("vector<int>"))
for l in range(0, self.C):
u = Graph(directed=g.is_directed())
......@@ -226,7 +237,7 @@ class LayeredBlockState(OverlapBlockState, BlockState):
self.block_list = Vector_size_t()
self.block_list.extend(arange(total_state.B, dtype="int"))
self.__layer_entropy = kwargs.get("layer_entropy", None)
self.__layer_entropy = extract_arg(kwargs, "layer_entropy", None)
if not self.overlap:
self._state = \
......@@ -241,6 +252,10 @@ class LayeredBlockState(OverlapBlockState, BlockState):
if _bm_test():
assert self.mrs.fa.sum() == self.eweight.fa.sum(), "inconsistent mrs!"
if len(kwargs) > 0:
raise ValueError("unrecognized keyword arguments: " +
str(list(kwargs.keys())))
def __get_base_u(self, u):
node_index = u.vp["vmap"].copy("int64_t")
pmap(node_index, self.total_state.node_index)
......@@ -269,14 +284,12 @@ class LayeredBlockState(OverlapBlockState, BlockState):
eweight=u.ep["weight"],
vweight=u.vp["weight"],
deg_corr=self.deg_corr,
force_weighted=self.is_weighted,
degs=degs,
max_BE=self.max_BE)
else:
base_u, node_index = self.__get_base_u(u)
state = OverlapBlockState(u, b=u.vp["b"].fa,
B=B,
vweight=u.vp["weight"],
node_index=node_index,
base_g=base_u,
deg_corr=self.deg_corr,
......
......@@ -71,6 +71,12 @@ class NestedBlockState(object):
if _bm_test():
self._consistency_check()
def _regen_levels(self):
for l in range(1, len(self.levels)):
state = self.levels[l]
nstate = self.levels[l-1].get_block_state(b=state.b, deg_corr=False)
self.levels[l] = nstate
def __repr__(self):
return "<NestedBlockState object, with base %s, and %d levels of sizes %s at 0x%x>" % \
(repr(self.levels[0]), len(self.levels),
......@@ -218,6 +224,38 @@ class NestedBlockState(object):
**kwargs)
return S
def move_vertex(self, v, s):
r"""Move vertex ``v`` to block ``s``."""
self.levels[0].move_vertex(v, s)
self._regen_levels()
def remove_vertex(self, v):
r"""Remove vertex ``v`` from its current group.
This optionally accepts a list of vertices to remove.
.. warning::
This will leave the state in an inconsistent state before the vertex
is returned to some other group, or if the same vertex is removed
twice.
"""
self.levels[0].remove_vertex(v)
self._regen_levels()
def add_vertex(self, v, r):
r"""Add vertex ``v`` to block ``r``.
This optionally accepts a list of vertices and blocks to add.
.. warning::
This can leave the state in an inconsistent state if a vertex is
added twice to the same group.
"""
self.levels[0].add_vertex(v, r)
self._regen_levels()
def get_edges_prob(self, edge_list, missing=True, entropy_args={}):
"""Compute the log-probability of the missing (or spurious if ``missing=False``)
edges given by ``edge_list`` (a list of ``(source, target)`` tuples, or
......
......@@ -74,15 +74,17 @@ class OverlapBlockState(BlockState):
def __init__(self, g, b=None, B=None, clabel=None, pclabel=None,
deg_corr=True, max_BE=1000, **kwargs):
kwargs = kwargs.copy()
# determine if there is a base graph, and overlapping structure
self.base_g = kwargs.get("base_g", None)
self.base_g = extract_arg(kwargs, "base_g", None)
# overlapping information
node_index = kwargs.get("node_index", None)
node_in_degs = kwargs.get("node_in_degs", None)
node_out_degs = kwargs.get("node_out_degs", None)
half_edges = kwargs.get("half_edges", None)
eindex = kwargs.get("eindex", None)
node_index = extract_arg(kwargs, "node_index", None)
node_in_degs = extract_arg(kwargs, "node_in_degs", None)
node_out_degs = extract_arg(kwargs, "node_out_degs", None)
half_edges = extract_arg(kwargs, "half_edges", None)
eindex = extract_arg(kwargs, "eindex", None)
if node_index is not None and self.base_g is None:
raise ValueError("Must specify base graph if node_index is specified...")
......@@ -203,6 +205,10 @@ class OverlapBlockState(BlockState):
self._abg = self.bg._get_any()
self._state = libinference.make_overlap_block_state(self, _get_rng())
if len(kwargs) > 0:
raise ValueError("unrecognized keyword arguments: " +
str(list(kwargs.keys())))
def __repr__(self):
return "<OverlapBlockState object with %d blocks,%s for graph %s, at 0x%x>" % \
(self.B, " degree corrected," if self.deg_corr else "",
......
......@@ -55,6 +55,12 @@ def dmask(d, ks):
del d[k]
return d
def extract_arg(kwargs, arg, default=None):
val = kwargs.get(arg, default)
if arg in kwargs:
del kwargs[arg]
return val
def check_verbose(verbose):
if isinstance(verbose, tuple):
return verbose[0] != False
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment