Commit c91eb270 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: add support for discrete edge weights

parent 81bcf5c4
Pipeline #215 failed with stage
in 2209 minutes and 53 seconds
......@@ -233,6 +233,8 @@ void export_blockmodel_state()
.value("none", weight_type::NONE)
.value("positive", weight_type::POSITIVE)
.value("signed", weight_type::SIGNED)
.value("discrete_geometric", weight_type::DISCRETE_GEOMETRIC)
.value("discrete_poisson", weight_type::DISCRETE_POISSON)
.value("delta_t", weight_type::DELTA_T);
def("make_block_state", &make_block_state);
......
......@@ -50,6 +50,8 @@ enum weight_type
NONE,
POSITIVE,
SIGNED,
DISCRETE_GEOMETRIC,
DISCRETE_POISSON,
DELTA_T
};
......@@ -165,6 +167,8 @@ public:
switch (_rec_type)
{
case weight_type::POSITIVE: // positive weights
case weight_type::DISCRETE_GEOMETRIC:
case weight_type::DISCRETE_POISSON:
mv_entries(gs._rec);
break;
case weight_type::SIGNED: // positive and negative weights
......@@ -214,6 +218,8 @@ public:
case weight_type::SIGNED: // signed weights
this->_bdrec[me] += get<2>(delta);
case weight_type::POSITIVE: // positive weights
case weight_type::DISCRETE_GEOMETRIC:
case weight_type::DISCRETE_POISSON:
this->_brec[me] += get<1>(delta);
}
});
......@@ -353,6 +359,8 @@ public:
case weight_type::SIGNED: // signed weights
_bdrec[me] -= _drec[e];
case weight_type::POSITIVE: // positive weights
case weight_type::DISCRETE_GEOMETRIC:
case weight_type::DISCRETE_POISSON:
_brec[me] -= _rec[e];
}
......@@ -447,6 +455,8 @@ public:
case weight_type::SIGNED: // signed weights
_bdrec[me] += _drec[e];
case weight_type::POSITIVE: // positive weights
case weight_type::DISCRETE_GEOMETRIC:
case weight_type::DISCRETE_POISSON:
_brec[me] += _rec[e];
}
}
......@@ -978,26 +988,44 @@ public:
dS += ps.get_delta_edges_dl(v, r, nr, gs._vweight, gs._g);
}
auto positive_entries_op = [&](auto&& w_log_P)
{
entries_op(m_entries, this->_emat,
[&](auto, auto, auto& me, auto& delta)
{
size_t ers = 0;
double xrs = 0;
if (me != _emat.get_null_edge())
{
ers = this->_mrs[me];
xrs = this->_brec[me];
}
auto d = get<0>(delta);
auto dx = get<1>(delta);
dS -= -w_log_P(ers, xrs);
dS += -w_log_P(ers + d, xrs + dx);
});
};
switch (_rec_type)
{
case weight_type::POSITIVE: // positive weights
entries_op(m_entries, _emat,
[&](auto, auto, auto& me, auto& delta)
{
size_t ers = 0;
double xrs = 0;
if (me != _emat.get_null_edge())
{
ers = this->_mrs[me];
xrs = this->_brec[me];
}
auto d = get<0>(delta);
auto dx = get<1>(delta);
dS -= -positive_w_log_P(ers, xrs,
this->_alpha, this->_beta);
dS += -positive_w_log_P(ers + d, xrs + dx,
this->_alpha, this->_beta);
});
positive_entries_op([&](auto N, auto x)
{ return positive_w_log_P(N, x,
this->_alpha,
this->_beta); });
break;
case weight_type::DISCRETE_GEOMETRIC:
positive_entries_op([&](auto N, auto x)
{ return geometric_w_log_P(N, x,
this->_alpha,
this->_beta); });
break;
case weight_type::DISCRETE_POISSON:
positive_entries_op([&](auto N, auto x)
{ return poisson_w_log_P(N, x,
this->_alpha,
this->_beta); });
break;
case weight_type::SIGNED: // positive and negative weights
entries_op(m_entries, _emat,
......@@ -1405,6 +1433,22 @@ public:
S += -positive_w_log_P(ers, xrs, _alpha, _beta);
}
break;
case weight_type::DISCRETE_GEOMETRIC:
for (auto me : edges_range(_bg))
{
auto ers = _mrs[me];
auto xrs = _brec[me];
S += -geometric_w_log_P(ers, xrs, _alpha, _beta);
}
break;
case weight_type::DISCRETE_POISSON:
for (auto me : edges_range(_bg))
{
auto ers = _mrs[me];
auto xrs = _brec[me];
S += -poisson_w_log_P(ers, xrs, _alpha, _beta);
}
break;
case weight_type::SIGNED: // positive and negative weights
for (auto me : edges_range(_bg))
{
......
......@@ -216,6 +216,7 @@ inline double eterm_dense(size_t r, size_t s, int ers, double wr_r,
// Weighted entropy terms
// exponential
template <class DT>
double positive_w_log_P(DT N, double x, double alpha, double beta)
{
......@@ -225,6 +226,7 @@ double positive_w_log_P(DT N, double x, double alpha, double beta)
+ alpha * log(beta) - (alpha + N) * log(beta + x);
}
// normal
template <class DT>
double signed_w_log_P(DT N, double x, double v, double m0, double k0, double v0,
double nu0)
......@@ -239,6 +241,28 @@ double signed_w_log_P(DT N, double x, double v, double m0, double k0, double v0,
- (nu_n / 2.) * log(nu_n * v_n) - (N/2.) * log(M_PI);
}
// discrete: geometric
//
// Log marginal likelihood of the covariates on a bundle of N edges whose
// weights sum to x, with each weight drawn i.i.d. from a geometric distribution
// P(w) = p (1 - p)^w and the success probability p integrated out under a
// Beta(alpha, beta) prior:
//
//     P(x | N) = B(N + alpha, x + beta) / B(alpha, beta)
//
// NOTE(review): this assumes the weights are supported on {0, 1, 2, ...};
// confirm against the callers if weights are shifted to start at 1.
//
// N     -- number of edges in the bundle (an empty bundle contributes 0)
// x     -- sum of the edge weights in the bundle
// alpha -- first shape hyperparameter of the Beta prior
// beta  -- second shape hyperparameter of the Beta prior
template <class DT>
double geometric_w_log_P(DT N, double x, double alpha, double beta)
{
    if (N == 0)
        return 0.;
    // log B(N + alpha, x + beta) - log B(alpha, beta), expanded in terms of
    // log-gamma functions. Both N and x must enter: the previous expression
    // hard-coded N = 1 and ignored x in the numerator.
    return boost::math::lgamma(N + alpha) + boost::math::lgamma(x + beta)
        - boost::math::lgamma(N + x + alpha + beta)
        + boost::math::lgamma(alpha + beta)
        - boost::math::lgamma(alpha) - boost::math::lgamma(beta);
}
// discrete: Poisson
//
// Log marginal likelihood (up to a partition-independent constant) of the
// covariates on a bundle of N edges whose weights sum to x, with each weight
// drawn i.i.d. from a Poisson distribution with rate lambda, and lambda
// integrated out under a Gamma(alpha, beta) prior (shape alpha, rate beta):
//
//     P(x | N) \propto Gamma(x + alpha) / Gamma(alpha)
//                      * beta^alpha / (beta + N)^(x + alpha)
//
// The omitted factor is prod_e 1 / w_e!, taken over the individual edge
// weights. It cannot be computed from the sum x alone and does not change when
// edges are regrouped, so it cancels in any entropy difference.
//
// N     -- number of edges in the bundle (an empty bundle contributes 0)
// x     -- sum of the edge weights in the bundle
// alpha -- shape hyperparameter of the Gamma prior
// beta  -- rate hyperparameter of the Gamma prior
template <class DT>
double poisson_w_log_P(DT N, double x, double alpha, double beta)
{
    if (N == 0)
        return 0.;
    // N observations update the Gamma rate to beta + N (not beta + 1).
    return boost::math::lgamma(x + alpha) - boost::math::lgamma(alpha)
        + alpha * log(beta) - (x + alpha) * log(beta + N);
}
// ===============
// Partition stats
......
......@@ -155,13 +155,14 @@ class BlockState(object):
Vertex multiplicities (for block graphs).
rec : :class:`~graph_tool.PropertyMap` (optional, default: ``None``)
Real-valued edge covariates.
rec_type : `"positive"`, `"signed"` or `None` (optional, default: ``None``)
rec_type : `"positive"`, `"signed"`, `"discrete-geometric"`, `"discrete-poisson"` or `None` (optional, default: ``None``)
Type of edge covariates. If not specified, it will be guessed from
``rec``.
rec_params : ``dict`` (optional, default: ``{}``)
Model hyperparameters for real-valued covariates. This should be a
``dict`` with keys in the list ``["alpha", "beta"]`` if ``rec_type ==
positive`` or ``["m0", "k0", "v0". "nu0"]`` if ``rec_type == signed``.
``dict`` with keys in the list ``["alpha", "beta"]`` if ``rec_type`` is
one of ``"positive"``, ``"discrete-geometric"``, ``"discrete-poisson"``,
or ``["m0", "k0", "v0", "nu0"]`` if ``rec_type == "signed"``.
b : :class:`~graph_tool.PropertyMap` (optional, default: ``None``)
Initial block labels on the vertices. If not supplied, it will be
randomly sampled.
......@@ -184,6 +185,7 @@ class BlockState(object):
max_BE : ``int`` (optional, default: ``1000``)
If the number of blocks exceeds this value, a sparse matrix is used for
the block graph. Otherwise a dense matrix will be used.
"""
def __init__(self, g, eweight=None, vweight=None, rec=None, rec_type=None,
......@@ -340,6 +342,10 @@ class BlockState(object):
self.rec_type = libinference.rec_type.positive
elif rec_type == "signed":
self.rec_type = libinference.rec_type.signed
elif rec_type == "discrete-geometric":
self.rec_type = libinference.rec_type.discrete_geometric
elif rec_type == "discrete-poisson":
self.rec_type = libinference.rec_type.discrete_poisson
elif rec_type == "delta_t":
self.rec_type = libinference.rec_type.delta_t
else:
......@@ -359,11 +365,14 @@ class BlockState(object):
self.rec_params = dict(m0=self.rec.fa.mean(), k0=1,
v0=self.rec.fa.std() ** 2, nu0=3)
if self.is_weighted:
idx = self.eweight.fa > 0
self.rec_params.update(dict(alpha=1, beta=self.rec.fa[idx].mean()))
if self.rec_type == libinference.rec_type.discrete_geometric:
self.rec_params.update(dict(alpha=1, beta=self.rec.fa.mean() - 1))
elif self.rec_type == libinference.rec_type.discrete_poisson:
self.rec_params.update(dict(alpha=1, beta=1./self.rec.fa.mean()))
else:
self.rec_params.update(dict(alpha=1, beta=self.rec.fa.mean()))
self.rec_params.update(rec_params)
self.__dict__.update(self.rec_params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment