Commit f606acc2 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: Slightly improve entry computation performance

parent e7363032
......@@ -88,15 +88,15 @@ void tuple_op_imp(T&, OP&&)
template <size_t i, class T, class OP, class Ti, class... Ts>
void tuple_op_imp(T& tuple, OP&& op, Ti&& v, Ts&&... vals)
{
op(get<i>(tuple), v);
tuple_op_imp<i+1>(tuple, op, vals...);
op(get<i>(tuple), std::forward<Ti>(v));
tuple_op_imp<i+1>(tuple, std::forward<OP>(op), std::forward<Ts>(vals)...);
}
template <class OP, class T, class... Ts>
__attribute__((flatten))
void tuple_op(T& tuple, OP&& op, Ts&&... vals)
{
tuple_op_imp<0>(tuple, op, vals...);
tuple_op_imp<0>(tuple, std::forward<OP>(op), std::forward<Ts>(vals)...);
}
namespace detail {
......@@ -191,21 +191,21 @@ public:
f = _entries.size();
_entries.emplace_back(s, t);
_delta.emplace_back();
// if (sizeof...(delta) > 0)
if (sizeof...(delta) > 0)
_edelta.emplace_back();
}
if (Add)
{
_delta[f] += d;
tuple_op(_edelta[f], [&](auto& r, auto& v){ r += v; },
delta...);
tuple_op(_edelta[f], [&](auto&& r, auto&& v){ r += v; },
std::forward<DVals>(delta)...);
}
else
{
_delta[f] -= d;
tuple_op(_edelta[f], [&](auto& r, auto& v){ r -= v; },
delta...);
tuple_op(_edelta[f], [&](auto&& r, auto&& v){ r -= v; },
std::forward<DVals>(delta)...);
}
}
......@@ -251,8 +251,7 @@ public:
const vector<pair<size_t, size_t>>& get_entries() { return _entries; }
const vector<int>& get_delta() { return _delta; }
const vector<std::tuple<EVals...>>& get_edelta() { //_edelta.resize(_delta.size());
return _edelta; }
const vector<std::tuple<EVals...>>& get_edelta() { _edelta.resize(_delta.size()); return _edelta; }
template <class Emat>
vector<bedge_t>& get_mes(Emat& emat)
......@@ -317,7 +316,7 @@ void modify_entries(Vertex v, Vertex r, Vertex nr, Vprop& _b, Graph& g,
int self_weight = 0;
if (!graph_tool::is_directed(g) && sizeof...(Eprops) > 0)
{
tuple_apply([&](auto&&... vals)
tuple_apply([&](auto&... vals)
{
auto op = [](auto& x) -> auto& { x *= 0; return x; };
auto f = [](auto&...) {};
......@@ -353,7 +352,7 @@ void modify_entries(Vertex v, Vertex r, Vertex nr, Vprop& _b, Graph& g,
if ((u == v || is_loop(v)) && !graph_tool::is_directed(g))
{
self_weight += ew;
tuple_op(eself_weight, [&](auto& x, auto& val){ x += val; },
tuple_op(eself_weight, [&](auto&& x, auto&& val){ x += val; },
make_vadapter(eprops, e)...);
}
}
......@@ -362,7 +361,7 @@ void modify_entries(Vertex v, Vertex r, Vertex nr, Vprop& _b, Graph& g,
{
if (sizeof...(Eprops) > 0)
{
tuple_apply([&](auto&&... vals)
tuple_apply([&](auto&... vals)
{
auto op = [](auto& x) -> auto& { x /= 2; return x; };
auto f = [](auto&...) {};
......@@ -452,6 +451,23 @@ void move_entries(Vertex v, size_t r, size_t nr, VProp& _b, Graph& g,
// operation on a set of entries
template <class MEntries, class EMat, class OP>
void entries_op(MEntries& m_entries, EMat& emat, OP&& op)
{
const auto& entries = m_entries.get_entries();
const auto& delta = m_entries.get_delta();
auto& mes = m_entries.get_mes(emat);
for (size_t i = 0; i < entries.size(); ++i)
{
auto& entry = entries[i];
auto er = entry.first;
auto es = entry.second;
op(er, es, mes[i], delta[i]);
}
}
// operation on a set of entries, with edge covariates
template <class MEntries, class EMat, class OP>
void wentries_op(MEntries& m_entries, EMat& emat, OP&& op)
{
const auto& entries = m_entries.get_entries();
const auto& delta = m_entries.get_delta();
......@@ -473,7 +489,7 @@ double entries_dS(MEntries& m_entries, Eprop& mrs, EMat& emat, BGraph& bg)
{
double dS = 0;
entries_op(m_entries, emat,
[&](auto r, auto s, auto& me, auto d, auto&)
[&](auto r, auto s, auto& me, auto d)
{
size_t ers = 0;
if (me != emat.get_null_edge())
......
......@@ -36,7 +36,7 @@ double virtual_move_covariate(size_t v, size_t r, size_t s, State& state,
double dS = 0;
entries_op(m_entries, state._emat,
[&](auto, auto, auto& me, auto d, auto&)
[&](auto, auto, auto& me, auto d)
{
int ers = (me != state._emat.get_null_edge()) ?
state._mrs[me] : 0;
......
......@@ -658,32 +658,32 @@ public:
auto&& w_log_prior)
{
int dB_E = 0;
entries_op(m_entries, this->_emat,
[&](auto, auto, auto& me, auto delta, auto& edelta)
{
double ers = 0;
double xrs = 0;
if (me != _emat.get_null_edge())
{
ers = this->_brec[0][me];
xrs = this->_brec[i][me];
}
auto d = get<0>(edelta)[0];
auto dx = get<0>(edelta)[i];
dS -= -w_log_P(ers, xrs);
dS += -w_log_P(ers + d, xrs + dx);
if (ea.recs_dl)
{
size_t ers = 0;
if (me != _emat.get_null_edge())
ers = this->_mrs[me];
if (ers == 0 && delta > 0)
dB_E++;
if (ers > 0 && ers + delta == 0)
dB_E--;
}
});
wentries_op(m_entries, this->_emat,
[&](auto, auto, auto& me, auto delta, auto& edelta)
{
double ers = 0;
double xrs = 0;
if (me != _emat.get_null_edge())
{
ers = this->_brec[0][me];
xrs = this->_brec[i][me];
}
auto d = get<0>(edelta)[0];
auto dx = get<0>(edelta)[i];
dS -= -w_log_P(ers, xrs);
dS += -w_log_P(ers + d, xrs + dx);
if (ea.recs_dl)
{
size_t ers = 0;
if (me != _emat.get_null_edge())
ers = this->_mrs[me];
if (ers == 0 && delta > 0)
dB_E++;
if (ers > 0 && ers + delta == 0)
dB_E--;
}
});
if (dB_E != 0 && ea.recs_dl && std::isnan(_wparams[i][0])
&& std::isnan(_wparams[i][1]))
{
......@@ -758,60 +758,58 @@ public:
int dB_E_D = 0;
double dBx2 = 0;
_dBdx[i] = 0;
entries_op(m_entries, _emat,
[&](auto, auto, auto& me, auto, auto& edelta)
{
double ers = 0;
double xrs = 0, x2rs = 0;
if (me != _emat.get_null_edge())
{
ers = this->_brec[0][me];
xrs = this->_brec[i][me];
x2rs = this->_bdrec[i][me];
}
auto d = get<0>(edelta)[0];
auto dx = get<0>(edelta)[i];
auto dx2 = get<1>(edelta)[i];
dS -= -signed_w_log_P(ers, xrs, x2rs,
wp[0], wp[1],
wp[2], wp[3],
this->_epsilon[i]);
dS += -signed_w_log_P(ers + d,
xrs + dx,
x2rs + dx2,
wp[0], wp[1],
wp[2], wp[3],
this->_epsilon[i]);
if (std::isnan(wp[0]) &&
std::isnan(wp[1]))
{
auto n_ers = ers + get<0>(edelta)[0];
if (ers == 0 && n_ers > 0)
dB_E++;
if (ers > 0 && n_ers == 0)
dB_E--;
if (n_ers > 1)
{
if (ers < 2)
dB_E_D++;
_dBdx[i] += \
(x2rs + dx2 -
std::pow(xrs + dx, 2) / n_ers);
}
if (ers > 1)
{
if (n_ers < 2)
dB_E_D--;
_dBdx[i] -= \
(x2rs -
std::pow(xrs, 2) / ers);
}
dBx2 += (std::pow(xrs + dx, 2) -
std::pow(xrs, 2));
}
});
wentries_op(m_entries, _emat,
[&](auto, auto, auto& me, auto, auto& edelta)
{
double ers = 0;
double xrs = 0, x2rs = 0;
if (me != _emat.get_null_edge())
{
ers = this->_brec[0][me];
xrs = this->_brec[i][me];
x2rs = this->_bdrec[i][me];
}
auto d = get<0>(edelta)[0];
auto dx = get<0>(edelta)[i];
auto dx2 = get<1>(edelta)[i];
dS -= -signed_w_log_P(ers, xrs, x2rs,
wp[0], wp[1],
wp[2], wp[3],
this->_epsilon[i]);
dS += -signed_w_log_P(ers + d,
xrs + dx,
x2rs + dx2,
wp[0], wp[1],
wp[2], wp[3],
this->_epsilon[i]);
if (std::isnan(wp[0]) &&
std::isnan(wp[1]))
{
auto n_ers = ers + get<0>(edelta)[0];
if (ers == 0 && n_ers > 0)
dB_E++;
if (ers > 0 && n_ers == 0)
dB_E--;
if (n_ers > 1)
{
if (ers < 2)
dB_E_D++;
_dBdx[i] += \
(x2rs + dx2 -
std::pow(xrs + dx, 2) / n_ers);
}
if (ers > 1)
{
if (n_ers < 2)
dB_E_D--;
_dBdx[i] -= \
(x2rs -
std::pow(xrs, 2) / ers);
}
dBx2 += (std::pow(xrs + dx, 2) -
std::pow(xrs, 2));
}
});
if (std::isnan(wp[0]) && std::isnan(wp[1]))
{
......@@ -904,12 +902,12 @@ public:
{
auto& recs_entries = m_entries._recs_entries;
recs_entries.clear();
entries_op(m_entries, _emat,
[&](auto r, auto s, auto& me, auto delta, auto& edelta)
{
recs_entries.emplace_back(r, s, me, delta,
get<0>(edelta));
});
wentries_op(m_entries, _emat,
[&](auto r, auto s, auto& me, auto delta, auto& edelta)
{
recs_entries.emplace_back(r, s, me, delta,
get<0>(edelta));
});
scoped_lock lck(_lock);
dS_dl += _coupled_state->recs_dS(r, nr, recs_entries, _dBdx, dL);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment