Commit bc21f238 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

Fix issue with CovariateBlockState.merge_layers() when layers == True

parent c6814980
......@@ -101,14 +101,6 @@ class CovariateBlockState(BlockState):
self.C = ec.fa.max() + 1
self.layers = layers
if self.layers:
# we need to include the membership of the nodes in each layer
be = group_vector_property([self.ec, self.ec])
lstate = OverlapBlockState(self.g, b=be)
self.layer_entropy = lstate.entropy(dl=True, edges_dl=False) - lstate.entropy(dl=False)
else:
self.layer_entropy = 0
if "max_BE" in kwargs:
del kwargs["max_BE"]
max_BE = 0
......@@ -213,6 +205,8 @@ class CovariateBlockState(BlockState):
self.overlap_stats = self.total_state.overlap_stats
self.__layer_entropy = None
if _bm_test():
assert self.mrs.fa.sum() == self.eweight.fa.sum(), "inconsistent mrs!"
......@@ -439,6 +433,8 @@ class CovariateBlockState(BlockState):
old_bmap = self.bmap.copy()
self.bmap.del_c(l_src + 1)
self.__bg = None
old_layer_entropy = self.__layer_entropy
self.__layer_entropy = None
yield
......@@ -452,6 +448,7 @@ class CovariateBlockState(BlockState):
self.base_ec.a[:] = old_base_ec.a
self.C += 1
self.bmap = old_bmap
self.__layer_entropy = old_layer_entropy
def __getstate__(self):
state = dict(g=self.g,
......@@ -609,6 +606,19 @@ class CovariateBlockState(BlockState):
else:
return self.total_state.get_majority_blocks()
def __get_layer_entropy(self):
if self.__layer_entropy is None:
if self.layers:
# we need to include the membership of the nodes in each layer
g = self.base_g if self.overlap else self.g
ec = self.base_ec if self.overlap else self.ec
be = group_vector_property([ec, ec])
lstate = OverlapBlockState(g, b=be, deg_corr=False)
self.__layer_entropy = lstate.entropy(dl=True, edges_dl=False) - lstate.entropy(dl=False)
else:
self.__layer_entropy = 0
return self.__layer_entropy
def entropy(self, complete=True, dl=False, partition_dl=True, edges_dl=True,
degree_dl=True, dense=False, multigraph=True, norm=False,
dl_ent=False, **kwargs):
......@@ -684,7 +694,7 @@ class CovariateBlockState(BlockState):
norm=False)
if dl:
S += self.layer_entropy
S += self.__get_layer_entropy()
if norm:
S /= self.E
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment