Commit f44436b4 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

Implement __copy__ and __deepcopy__ for Nested/Overlap/Covariate/BlockState

parent da569528
...@@ -225,13 +225,24 @@ class BlockState(object): ...@@ -225,13 +225,24 @@ class BlockState(object):
self.partition_stats = libcommunity.partition_stats() self.partition_stats = libcommunity.partition_stats()
def __copy__(self):
return self.copy()
def __deepcopy__(self, memo):
g = self.g.copy()
eweight = g.own_property(self.eweight.copy())
vweight = g.own_property(self.vweight.copy())
clabel = g.own_property(self.clabel.copy())
b = g.own_property(self.b.copy())
return self.copy(g=g, eweight=eweight, vweight=vweight, b=b,
clabel=clabel)
def copy(self, g=None, eweight=None, vweight=None, b=None, B=None, def copy(self, g=None, eweight=None, vweight=None, b=None, B=None,
deg_corr=None, clabel=None, overlap=False, **kwargs): deg_corr=None, clabel=None, overlap=False, **kwargs):
r"""Copies the block state. The parameters override the state properties, and r"""Copies the block state. The parameters override the state properties, and
have the same meaning as in the constructor. If ``overlap=True`` an have the same meaning as in the constructor. If ``overlap=True`` an
instance of :class:`~graph_tool.community.OverlapBlockState` is instance of :class:`~graph_tool.community.OverlapBlockState` is
returned.""" returned. This is by default a shallow copy."""
if not overlap: if not overlap:
state = BlockState(self.g if g is None else g, state = BlockState(self.g if g is None else g,
......
...@@ -462,6 +462,26 @@ class CovariateBlockState(BlockState): ...@@ -462,6 +462,26 @@ class CovariateBlockState(BlockState):
self.__init__(**state) self.__init__(**state)
return state return state
def __copy__(self):
return self.copy()
def __deepcopy__(self, memo):
if not self.overlap:
g = self.g.copy()
eweight = g.own_property(self.eweight.copy())
vweight = g.own_property(self.vweight.copy())
clabel = g.own_property(self.clabel.copy())
b = g.own_property(self.b.copy())
ec = g.own_property(self.ec.copy())
return self.copy(g=g, ec=ec, eweight=eweight, vweight=vweight, b=b,
clabel=clabel)
else:
g = self.base_g.copy()
clabel = self.clabel
b = self.b
ec = g.own_property(self.base_ec.copy())
return self.copy(g=g, ec=ec, b=b.fa, clabel=clabel.fa)
def copy(self, g=None, eweight=None, vweight=None, b=None, B=None, def copy(self, g=None, eweight=None, vweight=None, b=None, B=None,
deg_corr=None, clabel=None, overlap=None, layers=None, ec=None): deg_corr=None, clabel=None, overlap=None, layers=None, ec=None):
r"""Copies the block state. The parameters override the state properties, and r"""Copies the block state. The parameters override the state properties, and
......
...@@ -158,6 +158,20 @@ class NestedBlockState(object): ...@@ -158,6 +158,20 @@ class NestedBlockState(object):
" degree corrected," if self.deg_corr else "", " degree corrected," if self.deg_corr else "",
str(self.g), len(self.levels), str([(s.N, s.B) for s in self.levels]), id(self)) str(self.g), len(self.levels), str([(s.N, s.B) for s in self.levels]), id(self))
def __copy__(self):
return self.copy()
def __deepcopy__(self, memo):
g = self.g.copy()
eweight = g.own_property(self.eweight.copy()) if self.eweight is not None else None
vweight = g.own_property(self.vweight.copy()) if self.vweight is not None else None
clabel = g.own_property(self.clabel.copy()) if self.clabel is not None else None
ec = g.own_property(self.ec.copy()) if self.ec is not None else None
bstack = self.get_bstack()
return self.copy(g=g, eweight=eweight, vweight=vweight, clabel=clabel,
ec=ec, bs=[s.vp.b.a for s in bstack])
def copy(self, g=None, eweight=None, vweight=None, bs=None, ec=None, def copy(self, g=None, eweight=None, vweight=None, bs=None, ec=None,
layers=None, deg_corr=None, overlap=None, clabel=None, **kwargs): layers=None, deg_corr=None, overlap=None, clabel=None, **kwargs):
r"""Copies the block state. The parameters override the state properties, and r"""Copies the block state. The parameters override the state properties, and
......
...@@ -248,12 +248,21 @@ class OverlapBlockState(BlockState): ...@@ -248,12 +248,21 @@ class OverlapBlockState(BlockState):
else: else:
self.partition_stats = libcommunity.overlap_partition_stats() self.partition_stats = libcommunity.overlap_partition_stats()
def __copy__(self):
return self.copy()
def __deepcopy__(self, memo):
g = self.base_g.copy()
clabel = self.clabel.copy()
b = self.b
return self.copy(g=g, b=b.fa, clabel=clabel.fa)
def copy(self, g=None, eweight=None, vweight=None, b=None, B=None, def copy(self, g=None, eweight=None, vweight=None, b=None, B=None,
deg_corr=None, clabel=None, overlap=True): deg_corr=None, clabel=None, overlap=True):
r"""Copies the block state. The parameters override the state properties, and r"""Copies the block state. The parameters override the state properties, and
have the same meaning as in the constructor. If ``overlap=False`` an have the same meaning as in the constructor. If ``overlap=False`` an
instance of :class:`~graph_tool.community.BlockState` is returned.""" instance of :class:`~graph_tool.community.BlockState` is returned. This
is by default a shallow copy."""
if overlap: if overlap:
state = OverlapBlockState(self.g if g is None else g, state = OverlapBlockState(self.g if g is None else g,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment