Commit 5bf396b3 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: Improve block entry delta computation

parent f44e89c4
......@@ -249,13 +249,59 @@ public:
get_move_entries(v, r, null_group, _m_entries,
std::forward<EFilt>(efilt));
auto eops = [&](auto&& mid_op, auto&& end_op)
apply_delta<Add,!Add>(_m_entries);
if (Add)
BlockState::add_partition_node(v, r);
else
BlockState::remove_partition_node(v, r);
if (!_rec_types.empty() &&
_rec_types[1] == weight_type::DELTA_T) // waiting times
{
if (_ignore_degrees[v] > 0)
{
entries_op(_m_entries, _emat,
[&](auto r, auto s, auto& me, auto& delta)
auto dt = out_degreeS()(v, _g, _rec[1]);
if (Add)
_brecsum[r] += dt;
else
_brecsum[r] -= dt;
}
}
}
template <class EFilt>
void move_vertex(size_t v, size_t r, size_t nr, EFilt&& efilt)
{
get_move_entries(v, r, nr, _m_entries, std::forward<EFilt>(efilt));
apply_delta<true, true>(_m_entries);
BlockState::remove_partition_node(v, r);
BlockState::add_partition_node(v, nr);
if (!_rec_types.empty() &&
_rec_types[1] == weight_type::DELTA_T) // waiting times
{
if (_ignore_degrees[v] > 0)
{
auto dt = out_degreeS()(v, _g, _rec[1]);
_brecsum[r] -= dt;
_brecsum[nr] += dt;
}
}
}
template <bool Add, bool Remove, class MEntries>
void apply_delta(MEntries& m_entries)
{
auto eops = [&](auto&& mid_op, auto&& end_op, auto&& skip)
{
entries_op(m_entries, _emat,
[&](auto r, auto s, auto& me, auto delta, auto& edelta)
{
if (get<0>(delta) == 0) // can happen with
return; // zero-weight edges
if (skip(delta, edelta))
return;
if (Add && me == _emat.get_null_edge())
{
......@@ -271,19 +317,19 @@ public:
_coupled_state->add_edge(me);
}
mid_op(me, delta);
mid_op(me, edelta);
this->_mrs[me] += get<0>(delta);
this->_mrp[r] += get<0>(delta);
this->_mrm[s] += get<0>(delta);
this->_mrs[me] += delta;
this->_mrp[r] += delta;
this->_mrm[s] += delta;
assert(this->_mrs[me] >= 0);
assert(this->_mrp[r] >= 0);
assert(this->_mrm[s] >= 0);
end_op(me, delta);
end_op(me, edelta);
if (!Add && this->_mrs[me] == 0)
if (Remove && this->_mrs[me] == 0)
{
this->_emat.remove_me(me, this->_bg);
if (_coupled_state != nullptr)
......@@ -296,37 +342,48 @@ public:
if (_rec_types.empty())
{
eops([](auto&, auto&){}, [](auto&, auto&){});
eops([](auto&, auto&){}, [](auto&, auto&){},
[](auto delta, auto&) { return delta == 0; });
}
else
{
auto end_op = [&](auto& me, auto& delta)
auto skip = [&](auto delta, auto& edelta)
{
if (delta != 0)
return false;
for (size_t i = 0; i < this->_rec_types.size(); ++i)
{
if (get<0>(edelta)[i] != 0)
return false;
if (this->_rec_types[i] == weight_type::REAL_NORMAL &&
get<1>(edelta)[i] != 0)
return false;
}
return true;
};
auto end_op = [&](auto& me, auto& edelta)
{
for (size_t i = 0; i < this->_rec_types.size(); ++i)
{
switch (this->_rec_types[i])
{
case weight_type::REAL_NORMAL: // signed weights
this->_bdrec[i][me] += get<2>(delta)[i];
[[gnu::fallthrough]];
default:
this->_brec[i][me] += get<1>(delta)[i];
}
this->_brec[i][me] += get<0>(edelta)[i];
if (this->_rec_types[i] == weight_type::REAL_NORMAL)
this->_bdrec[i][me] += get<1>(edelta)[i];
}
};
auto mid_op_BE =
[&](auto& me, auto&& delta)
[&](auto& me, auto& edelta)
{
auto mrs = this->_brec[0][me];
if (Add && mrs == 0 && mrs + get<1>(delta)[0] > 0)
if (Add && mrs == 0 && (mrs + get<0>(edelta)[0]) > 0)
{
_B_E++;
if (_coupled_state != nullptr)
_coupled_state->add_edge_rec(me);
}
if (!Add && mrs > 0 && mrs + get<1>(delta)[0] == 0)
if (Remove && mrs > 0 && (mrs + get<0>(edelta)[0]) == 0)
{
_B_E--;
if (_coupled_state != nullptr)
......@@ -336,17 +393,17 @@ public:
if (_rt != weight_type::REAL_NORMAL)
{
eops(mid_op_BE, end_op);
eops(mid_op_BE, end_op, skip);
}
else
{
auto mid_op =
[&](auto& me, auto&& delta)
[&](auto& me, auto& edelta)
{
auto& mrs = this->_brec[0][me];
mid_op_BE(me, delta);
mid_op_BE(me, edelta);
auto n_mrs = mrs + get<1>(delta)[0];
auto n_mrs = mrs + get<0>(edelta)[0];
if (n_mrs > 1)
{
......@@ -362,16 +419,16 @@ public:
if (this->_rec_types[i] != weight_type::REAL_NORMAL)
continue;
auto dx = \
(this->_bdrec[i][me] + get<2>(delta)[i]
(this->_bdrec[i][me] + get<1>(edelta)[i]
- (std::pow((this->_brec[i][me] +
get<1>(delta)[i]), 2) / n_mrs));
get<0>(edelta)[i]), 2) / n_mrs));
this->_recdx[i] += dx;
}
}
if (mrs > 1)
{
if (!Add && n_mrs < 2)
if (Remove && n_mrs < 2)
{
_B_E_D--;
if (_B_E_D == 0 && this->_Lrecdx[0] >= 0)
......@@ -394,17 +451,17 @@ public:
{
_recx2[i] -= std::pow(this->_brec[i][me], 2);
_recx2[i] += std::pow(this->_brec[i][me] +
get<1>(delta)[i], 2);
get<0>(edelta)[i], 2);
}
}
};
auto coupled_end_op = [&](auto& me, auto& delta)
auto coupled_end_op = [&](auto& me, auto& edelta)
{
end_op(me, delta);
end_op(me, edelta);
if (_coupled_state != nullptr)
_coupled_state->update_edge_rec(me, get<1>(delta));
_coupled_state->update_edge_rec(me, get<0>(edelta));
};
if (_Lrecdx[0] >= 0)
......@@ -413,7 +470,7 @@ public:
_Lrecdx[i+1] -= _recdx[i] * _B_E_D;
}
eops(mid_op, coupled_end_op);
eops(mid_op, coupled_end_op, skip);
if (_Lrecdx[0] >= 0)
{
......@@ -423,25 +480,47 @@ public:
}
}
if (!_rec_types.empty() &&
_rec_types[1] == weight_type::DELTA_T) // waiting times
if (_coupled_state != nullptr)
{
if (_ignore_degrees[v] > 0)
{
auto dt = out_degreeS()(v, _g, _rec[1]);
if (Add)
_brecsum[r] += dt;
else
_brecsum[r] -= dt;
}
_p_entries.clear();
entries_op(m_entries, _emat,
[&](auto r, auto s, auto&, auto delta, auto& edelta)
{
if (delta == 0)
return;
_p_entries.emplace_back(r, s, delta, get<0>(edelta));
});
if (!_p_entries.empty())
_coupled_state->propagate_delta(m_entries.get_move().first,
m_entries.get_move().second,
_p_entries);
}
if (Add)
BlockState::add_partition_node(v, r);
else
BlockState::remove_partition_node(v, r);
}
void propagate_delta(size_t u, size_t v,
std::vector<std::tuple<size_t, size_t, int,
std::vector<double>>>& entries)
{
size_t r = (u != null_group) ? _b[u] : null_group;
size_t s = (v != null_group) ? _b[v] : null_group;
_m_entries.set_move(r, s, num_vertices(_bg));
for (auto& rsd : entries)
_m_entries.template insert_delta<true>(_b[get<0>(rsd)], _b[get<1>(rsd)],
get<2>(rsd));
apply_delta<true, true>(_m_entries);
entries.clear();
entries_op(_m_entries, _emat,
[&](auto r, auto s, auto&, auto delta, auto& edelta)
{
if (delta == 0)
return;
entries.emplace_back(r, s, delta, get<0>(edelta));
});
if (_coupled_state != nullptr && !entries.empty())
_coupled_state->propagate_delta(r, s, entries);
}
void add_edge(const GraphInterface::edge_t& e)
{
......@@ -860,8 +939,7 @@ public:
}
}
remove_vertex(v, r, [](auto&) {return false;});
add_vertex(v, nr, [](auto&) {return false;});
move_vertex(v, r, nr, [](auto&) {return false;});
}
void move_vertex(size_t v, size_t nr)
......@@ -1454,7 +1532,7 @@ public:
{
int dB_E = 0;
entries_op(m_entries, this->_emat,
[&](auto, auto, auto& me, auto& delta)
[&](auto, auto, auto& me, auto delta, auto& edelta)
{
double ers = 0;
double xrs = 0;
......@@ -1463,8 +1541,8 @@ public:
ers = this->_brec[0][me];
xrs = this->_brec[i][me];
}
auto d = get<1>(delta)[0];
auto dx = get<1>(delta)[i];
auto d = get<0>(edelta)[0];
auto dx = get<0>(edelta)[i];
dS -= -w_log_P(ers, xrs);
dS += -w_log_P(ers + d, xrs + dx);
......@@ -1473,9 +1551,9 @@ public:
size_t ers = 0;
if (me != _emat.get_null_edge())
ers = this->_mrs[me];
if (ers == 0 && get<0>(delta) > 0)
if (ers == 0 && delta > 0)
dB_E++;
if (ers > 0 && ers + get<0>(delta) == 0)
if (ers > 0 && ers + delta == 0)
dB_E--;
}
});
......@@ -1554,7 +1632,7 @@ public:
double dBx2 = 0;
_dBdx[i] = 0;
entries_op(m_entries, _emat,
[&](auto, auto, auto& me, auto& delta)
[&](auto, auto, auto& me, auto, auto& edelta)
{
double ers = 0;
double xrs = 0, x2rs = 0;
......@@ -1564,9 +1642,9 @@ public:
xrs = this->_brec[i][me];
x2rs = this->_bdrec[i][me];
}
auto d = get<1>(delta)[0];
auto dx = get<1>(delta)[i];
auto dx2 = get<2>(delta)[i];
auto d = get<0>(edelta)[0];
auto dx = get<0>(edelta)[i];
auto dx2 = get<1>(edelta)[i];
dS -= -signed_w_log_P(ers, xrs, x2rs,
wp[0], wp[1],
wp[2], wp[3],
......@@ -1580,7 +1658,7 @@ public:
if (std::isnan(wp[0]) &&
std::isnan(wp[1]))
{
auto n_ers = ers + get<1>(delta)[0];
auto n_ers = ers + get<0>(edelta)[0];
if (ers == 0 && n_ers > 0)
dB_E++;
if (ers > 0 && n_ers == 0)
......@@ -1725,11 +1803,10 @@ public:
auto& recs_entries = m_entries._recs_entries;
recs_entries.clear();
entries_op(m_entries, _emat,
[&](auto r, auto s, auto& me, auto& delta)
[&](auto r, auto s, auto& me, auto delta, auto& edelta)
{
recs_entries.emplace_back(r, s, me,
get<0>(delta),
get<1>(delta));
delta, get<0>(edelta));
});
scoped_lock lck(_lock);
......@@ -1805,12 +1882,11 @@ public:
auto w_entries_op = [&](auto&& w_log_P)
{
entries_op(_m_entries, _emat,
[&](auto, auto, auto& me, auto& delta)
[&](auto, auto, auto& me, auto d, auto&)
{
int ers = 0;
if (me != _emat.get_null_edge())
ers = this->_brec[0][me];
auto d = get<0>(delta);
if (d != 0)
{
dS -= -w_log_P(ers, me);
......@@ -1851,7 +1927,7 @@ public:
int dB_E_D = 0;
double drecdx = 0;
entries_op(_m_entries, _emat,
[&](auto, auto, auto& me, auto& delta)
[&](auto, auto, auto& me, auto d, auto& edelta)
{
int ers = 0;
double x = 0, x2 = 0;
......@@ -1861,8 +1937,7 @@ public:
x = this->_brec[i][me];
x2 = this->_bdrec[i][me];
}
auto d = get<0>(delta);
auto dx2 = get<1>(delta)[i];
auto dx2 = get<0>(edelta)[i];
dS -= -signed_w_log_P(ers, x, x2, wp[0],
wp[1], wp[2], wp[3],
this->_epsilon[i]);
......@@ -2070,10 +2145,10 @@ public:
if (reverse)
{
int dts = get<0>(m_entries.get_delta(t, s));
int dts = m_entries.get_delta(t, s);
int dst = dts;
if (is_directed::apply<g_t>::type::value)
dst = get<0>(m_entries.get_delta(s, t));
dst = m_entries.get_delta(s, t);
mts += dts;
mst += dst;
......@@ -2634,10 +2709,13 @@ public:
std::vector<partition_stats_t> _partition_stats;
std::vector<size_t> _bmap;
typedef EntrySet<g_t, bg_t, int, std::vector<double>,
typedef EntrySet<g_t, bg_t, std::vector<double>,
std::vector<double>> m_entries_t;
m_entries_t _m_entries;
std::vector<std::tuple<size_t, size_t, int, std::vector<double>>>
_p_entries;
BlockStateVirtualBase* _coupled_state = nullptr;
entropy_args_t _coupled_entropy_args;
......
......@@ -130,8 +130,11 @@ public:
{
_r_out_field.resize(B, _null);
_nr_out_field.resize(B, _null);
_r_in_field.resize(B, _null);
_nr_in_field.resize(B, _null);
if (is_directed::apply<Graph>::type::value)
{
_r_in_field.resize(B, _null);
_nr_in_field.resize(B, _null);
}
}
void set_move(size_t r, size_t nr, size_t B)
......@@ -142,53 +145,91 @@ public:
{
_r_out_field.resize(B, _null);
_nr_out_field.resize(B, _null);
_r_in_field.resize(B, _null);
_nr_in_field.resize(B, _null);
if (is_directed::apply<Graph>::type::value)
{
_r_in_field.resize(B, _null);
_nr_in_field.resize(B, _null);
}
}
}
size_t& get_field(size_t s, size_t t)
const pair<size_t, size_t>& get_move() { return _rnr; }
template <bool First, bool Source>
size_t& get_field_rnr(size_t s, size_t t)
{
if (!is_directed::apply<Graph>::type::value && s > t)
std::swap(s, t);
auto& out_field = First ? _r_out_field : _nr_out_field;
if (is_directed::apply<Graph>::type::value)
{
auto& in_field = (First ? _r_in_field : _nr_in_field);
return (Source || s == t) ? out_field[t] : in_field[s];
}
else
{
return (Source) ? out_field[t] : out_field[s];
}
}
size_t& get_field(size_t s, size_t t)
{
if (s == _rnr.first)
return _r_out_field[t];
else if (s == _rnr.second)
return _nr_out_field[t];
else if (t == _rnr.first)
return _r_in_field[s];
else if (t == _rnr.second)
return _nr_in_field[s];
else
return _dummy;
return get_field_rnr<true, true>(s, t);
if (t == _rnr.first)
return get_field_rnr<true, false>(s, t);
if (s == _rnr.second)
return get_field_rnr<false, true>(s, t);
if (t == _rnr.second)
return get_field_rnr<false, false>(s, t);
return _dummy;
}
template <bool Add, class... DVals>
__attribute__((flatten))
void insert_delta(size_t s, size_t t, DVals&&... delta)
void insert_delta_dispatch(size_t s, size_t t, size_t& f, int d, DVals&&... delta)
{
auto& f = get_field(s, t);
if (f == _null)
{
f = _entries.size();
_entries.emplace_back(s, t);
_delta.emplace_back();
if (sizeof...(delta) > 0)
_edelta.emplace_back();
}
if (Add)
tuple_op(_delta[f], [&](auto& r, auto& v){ r += v; },
{
_delta[f] += d;
tuple_op(_edelta[f], [&](auto& r, auto& v){ r += v; },
delta...);
}
else
tuple_op(_delta[f], [&](auto& r, auto& v){ r -= v; },
{
_delta[f] -= d;
tuple_op(_edelta[f], [&](auto& r, auto& v){ r -= v; },
delta...);
}
}
template <bool First, bool Source, bool Add, class... DVals>
__attribute__((flatten))
void insert_delta_rnr(size_t s, size_t t, int d, DVals&&... delta)
{
auto& f = get_field_rnr<First, Source>(s, t);
insert_delta_dispatch<Add>(s, t, f, d, std::forward<DVals>(delta)...);
}
const auto& get_delta(size_t r, size_t s)
template <bool Add, class... DVals>
__attribute__((flatten))
void insert_delta(size_t s, size_t t, int d, DVals&&... delta)
{
auto& f = get_field(s, t);
insert_delta_dispatch<Add>(s, t, f, d, std::forward<DVals>(delta)...);
}
int get_delta(size_t r, size_t s)
{
size_t f = get_field(r, s);
if (f == _null)
return _null_delta;
return 0;
return _delta[f];
}
......@@ -203,12 +244,14 @@ public:
}
_entries.clear();
_delta.clear();
_edelta.clear();
_mes.clear();
_recs_entries.clear();
}
const vector<pair<size_t, size_t> >& get_entries() { return _entries; }
const vector<std::tuple<EVals...>>& get_delta() { return _delta; }
const vector<pair<size_t, size_t>>& get_entries() { return _entries; }
const vector<int>& get_delta() { return _delta; }
const vector<std::tuple<EVals...>>& get_edelta() { return _edelta; }
template <class Emat>
vector<bedge_t>& get_mes(Emat& emat)
......@@ -232,7 +275,7 @@ public:
return _mes[field];
}