Commit 358ba2f4 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

blockmodel: fix in/out-degree description length for directed graphs

parent 3dbe58ef
......@@ -2317,12 +2317,17 @@ public:
void check_node_counts()
{
#ifndef NDEBUG
vector<size_t> wr(num_vertices(_bg));
for (auto v : vertices_range(_g))
wr[_b[v]] += _vweight[v];
for (auto r : vertices_range(_bg))
assert(size_t(_wr[r]) == wr[r]);
if (_coupled_state != nullptr)
_coupled_state->check_node_counts();
#endif
}
template <class V>
......
......@@ -60,17 +60,19 @@ class partition_stats
{
public:
typedef gt_hash_map<pair<size_t,size_t>, int> map_t;
typedef gt_hash_map<size_t, int> map_t;
template <class Graph, class Vprop, class VWprop, class Eprop, class Degs,
class Vlist>
partition_stats(Graph& g, Vprop& b, Vlist&& vlist, size_t E, size_t B,
VWprop& vweight, Eprop& eweight, Degs& degs)
: _N(0), _E(E), _total_B(B)
: _directed(graph_tool::is_directed(g)), _N(0), _E(E), _total_B(B)
{
if constexpr (!use_rmap)
{
_hist.resize(B, nullptr);
if (_directed)
_hist_in.resize(B, nullptr);
_hist_out.resize(B, nullptr);
_total.resize(B);
_ep.resize(B);
_em.resize(B);
......@@ -85,7 +87,9 @@ public:
auto [kin, kout] = get_deg(v, eweight, degs, g);
auto n = vweight[v];
get_hist(r)[make_pair(kin, kout)] += n;
if (_directed)
get_hist<false>(r)[kin] += n;
get_hist<true>(r)[kout] += n;
_em[r] += kin * n;
_ep[r] += kout * n;
_total[r] += n;
......@@ -106,31 +110,42 @@ public:
_E(o._E),
_actual_B(o._actual_B),
_total_B(o._total_B),
_hist(o._hist),
_hist_in(o._hist_in),
_hist_out(o._hist_out),
_total(o._total),
_ep(o._ep),
_em(o._em)
{
for (size_t r = 0; r < _hist.size(); ++r)
typedef decltype(_hist_out) hist_t;
for (auto* h : std::array<hist_t*,2>({&_hist_out, &_hist_in}))
{
if (_hist[r] != nullptr)
_hist[r] = new map_t(*_hist[r]);
auto& hist = *h;
for (size_t r = 0; r < hist.size(); ++r)
{
if (hist[r] != nullptr)
hist[r] = new map_t(*hist[r]);
}
}
}
~partition_stats()
{
for (auto* h : _hist)
for (auto* h : _hist_in)
{
if (h != nullptr)
delete h;
}
for (auto* h : _hist_out)
{
if (h != nullptr)
delete h;
}
}
template <bool create=true>
template <bool out, bool create=true>
auto& get_hist(size_t r)
{
auto& h = _hist[r];
auto h = (out) ? _hist_out[r] : _hist_in[r];
if (h == nullptr)
{
if constexpr (!create)
......@@ -150,12 +165,14 @@ public:
_bmap.resize(r + 1, null);
size_t nr = _bmap[r];
if (nr == null)
nr = _bmap[r] = _hist.size();
nr = _bmap[r] = _hist_out.size();
r = nr;
}
if (r >= _hist.size())
if (r >= _hist_out.size())
{
_hist.resize(r + 1, nullptr);
if (_directed)
_hist_in.resize(r + 1, nullptr);
_hist_out.resize(r + 1, nullptr);
_total.resize(r + 1);
_ep.resize(r + 1);
_em.resize(r + 1);
......@@ -184,7 +201,13 @@ public:
size_t total = 0;
if (ks.empty())
{
for (auto& k_c : get_hist<false>(r))
if (_directed)
{
for (auto& k_c : get_hist<false, false>(r))
S -= xlogx_fast(k_c.second);
}
for (auto& k_c : get_hist<true, false>(r))
{
S -= xlogx_fast(k_c.second);
total += k_c.second;
......@@ -192,16 +215,28 @@ public:
}
else
{
auto& h = get_hist<false>(r);
auto& h_out = get_hist<true, false>(r);
auto& h_in = (_directed) ? get_hist<false, false>(r) : h_out;
for (auto& k : ks)
{
auto iter = h.find(k);
auto k_c = (iter != h.end()) ? iter->second : 0;
if (_directed)
{
auto iter = h_in.find(get<0>(k));
auto k_c = (iter != h_in.end()) ? iter->second : 0;
S -= xlogx(k_c);
}
auto iter = h_out.find(get<0>(k));
auto k_c = (iter != h_out.end()) ? iter->second : 0;
S -= xlogx(k_c);
}
total = _total[r];
}
S += xlogx_fast(total);
if (_directed)
S += 2 * xlogx_fast(total);
else
S += xlogx_fast(total);
}
return S;
}
......@@ -232,7 +267,13 @@ public:
size_t total = 0;
if (ks.empty())
{
for (auto& k_c : get_hist<false>(r))
if (_directed)
{
for (auto& k_c : get_hist<false, false>(r))
S -= lgamma_fast(k_c.second + 1);
}
for (auto& k_c : get_hist<true, false>(r))
{
S -= lgamma_fast(k_c.second + 1);
total += k_c.second;
......@@ -240,16 +281,29 @@ public:
}
else
{
auto& h = get_hist<false>(r);
auto& h_out = get_hist<true, false>(r);
auto& h_in = (_directed) ? get_hist<false, false>(r) : h_out;
for (auto& k : ks)
{
auto iter = h.find(k);
auto k_c = (iter != h.end()) ? iter->second : 0;
if (_directed)
{
auto iter = h_in.find(get<0>(k));
auto k_c = (iter != h_in.end()) ? iter->second : 0;
S -= lgamma_fast(k_c + 1);
}
auto iter = h_out.find(get<1>(k));
auto k_c = (iter != h_out.end()) ? iter->second : 0;
S -= lgamma_fast(k_c + 1);
}
total = _total[r];
}
S += lgamma_fast(total + 1);
if (_directed)
S += 2 * lgamma_fast(total + 1);
else
S += lgamma_fast(total + 1);
}
return S;
}
......@@ -439,15 +493,26 @@ public:
auto get_Sk = [&](size_t s, pair<size_t, size_t>& deg, int delta)
{
int nd = 0;
if (_hist[s] != nullptr)
if (_directed && _hist_in[s] != nullptr)
{
auto& h = *_hist[s];
auto iter = h.find(deg);
auto& h = *_hist_in[s];
auto iter = h.find(get<0>(deg));
if (iter != h.end())
nd = iter->second;
}
assert(nd + delta >= 0);
return -xlogx_fast(nd + delta);
double S = -xlogx_fast(nd + delta);
nd = 0;
if (_hist_out[s] != nullptr)
{
auto& h = *_hist_out[s];
auto iter = h.find(get<1>(deg));
if (iter != h.end())
nd = iter->second;
}
return S -xlogx_fast(nd + delta);
};
double S_b = 0, S_a = 0;
......@@ -461,8 +526,16 @@ public:
S_a += get_Sk(r, deg, diff * nk);
});
S_b += xlogx_fast(nr);
S_a += xlogx_fast(nr + dn);
if (_directed)
{
S_b += 2 * xlogx_fast(nr);
S_a += 2 * xlogx_fast(nr + dn);
}
else
{
S_b += xlogx_fast(nr);
S_a += xlogx_fast(nr + dn);
}
return S_a - S_b;
}
......@@ -523,15 +596,25 @@ public:
auto get_Sk = [&](pair<size_t, size_t>& deg, int delta)
{
int nd = 0;
if (_hist[r] != nullptr)
if (_directed && _hist_in[r] != nullptr)
{
auto& h = *_hist[r];
auto iter = h.find(deg);
auto& h = *_hist_in[r];
auto iter = h.find(get<0>(deg));
if (iter != h.end())
nd = iter->second;
}
assert(nd + delta >= 0);
return -lgamma_fast(nd + delta + 1);
double S = -lgamma_fast(nd + delta + 1);
nd = 0;
if (_hist_out[r] != nullptr)
{
auto& h = *_hist_out[r];
auto iter = h.find(get<1>(deg));
if (iter != h.end())
nd = iter->second;
}
return S -lgamma_fast(nd + delta + 1);
};
double S_b = 0, S_a = 0;
......@@ -581,19 +664,27 @@ public:
auto [kin, kout] = get_deg(v, eweight, degs, g);
auto n = vweight[v];
int dk = diff * n;
auto& h = get_hist(r);
auto deg = make_pair(kin, kout);
auto iter = h.insert({deg, 0}).first;
iter->second += dk;
if (iter->second == 0)
h.erase(iter);
if (h.empty())
{
delete _hist[r];
_hist[r] = nullptr;
}
_em[r] += dk * deg.first;
_ep[r] += dk * deg.second;
auto change_hist =
[&](auto& hist, auto& h, size_t k)
{
auto iter = h.insert({k, 0}).first;
iter->second += dk;
if (iter->second == 0)
h.erase(iter);
if (h.empty())
{
delete hist[r];
hist[r] = nullptr;
}
};
if (_directed)
change_hist(_hist_in, get_hist<false>(r), kin);
change_hist(_hist_out, get_hist<true>(r), kout);
_em[r] += dk * kin;
_ep[r] += dk * kout;
}
template <class Graph, class VWeight, class EWeight, class Degs>
......@@ -645,63 +736,23 @@ public:
_total_B++;
if constexpr (!use_rmap)
{
_hist.resize(_total_B);
if (_directed)
_hist_in.resize(_total_B);
_hist_out.resize(_total_B);
_total.resize(_total_B);
_ep.resize(_total_B);
_em.resize(_total_B);
}
}
template <class Graph, class VProp, class VWeight, class EWeight, class Degs>
bool check_degs(Graph& g, VProp& b, VWeight& vweight, EWeight& eweight, Degs& degs)
{
vector<map_t> dhist;
for (auto v : vertices_range(g))
{
auto [kin, kout] = get_deg(v, eweight, degs, g);
auto n = vweight[v];
auto r = get_r(b[v]);
if (r >= dhist.size())
dhist.resize(r + 1);
dhist[r][{kin, kout}] += n;
}
for (size_t r = 0; r < dhist.size(); ++r)
{
for (auto& kn : dhist[r])
{
auto count = (r >= _hist.size()) ? 0 : get_hist(r)[kn.first];
if (kn.second != count)
{
assert(false);
return false;
}
}
}
for (size_t r = 0; r < _hist.size(); ++r)
{
for (auto& kn : get_hist<false>(r))
{
auto count = (r >= dhist.size()) ? 0 : dhist[r][kn.first];
if (kn.second != count)
{
assert(false);
return false;
}
}
}
return true;
}
private:
bool _directed;
vector<size_t> _bmap;
size_t _N;
size_t _E;
size_t _actual_B;
size_t _total_B;
vector<map_t*> _hist;
vector<map_t*> _hist_in, _hist_out;
vector<int> _total;
vector<int> _ep;
vector<int> _em;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment