Commit 5369eb34 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: Improve caching of lgamma()/log()/xlogx()

parent a3c332d0
......@@ -1643,8 +1643,8 @@ public:
{
size_t N_B_E_D = _B_E_D + dB_E_D;
dS_dl -= -safelog(_B_E_D);
dS_dl += -safelog(N_B_E_D);
dS_dl -= -safelog_fast(_B_E_D);
dS_dl += -safelog_fast(N_B_E_D);
_dBdx[i] = _recdx[i] * dB_E_D + _dBdx[i] * N_B_E_D;
......@@ -1912,8 +1912,8 @@ public:
dS += -positive_w_log_P(L + dL + ddL,
_Lrecdx[i+1] + dx, wp[2],
wp[3], _epsilon[i]);
dS -= -safelog(_B_E_D);
dS += -safelog(N_B_E_D);
dS -= -safelog_fast(_B_E_D);
dS += -safelog_fast(N_B_E_D);
}
}
}
......
......@@ -88,9 +88,9 @@ inline double vterm_exact(size_t mrp, size_t mrm, size_t wr, bool deg_corr,
else
{
if (graph_tool::is_directed(g))
return (mrp + mrm) * safelog(wr);
return (mrp + mrm) * safelog_fast(wr);
else
return mrp * safelog(wr);
return mrp * safelog_fast(wr);
}
}
......@@ -101,7 +101,7 @@ inline double eterm(size_t r, size_t s, size_t mrs, const Graph& g)
if (!graph_tool::is_directed(g) && r == s)
mrs *= 2;
double val = xlogx(mrs);
double val = xlogx_fast(mrs);
if (graph_tool::is_directed(g) || r != s)
return -val;
......@@ -120,9 +120,9 @@ inline double vterm(size_t mrp, size_t mrm, size_t wr, bool deg_corr,
one = 1;
if (deg_corr)
return one * (xlogx(mrm) + xlogx(mrp));
return one * (xlogx_fast(mrm) + xlogx_fast(mrp));
else
return one * (mrm * safelog(wr) + mrp * safelog(wr));
return one * (mrm * safelog_fast(wr) + mrp * safelog_fast(wr));
}
......@@ -132,11 +132,10 @@ inline double vterm(size_t mrp, size_t mrm, size_t wr, bool deg_corr,
// "edge" term of the entropy
template <class Graph>
inline double eterm_dense(size_t r, size_t s, int ers, double wr_r,
double wr_s, bool multigraph, const Graph& g)
inline double eterm_dense(size_t r, size_t s, uint64_t ers, uint64_t wr_r,
uint64_t wr_s, bool multigraph, const Graph& g)
{
// we should not use integers here, since they may overflow
double nrns;
uint64_t nrns; // avoid overflow for nr < 2^32
if (ers == 0)
return 0.;
......@@ -157,9 +156,9 @@ inline double eterm_dense(size_t r, size_t s, int ers, double wr_r,
double S;
if (multigraph)
S = lbinom(nrns + ers - 1, ers); // do not use lbinom_fast!
S = lbinom_fast<false>(nrns + ers - 1, ers); // do not use lbinom_fast<true>!
else
S = lbinom(nrns, ers);
S = lbinom_fast<false>(nrns, ers);
return S;
}
......@@ -168,7 +167,7 @@ template <class Graph>
double get_edges_dl(size_t B, size_t E, Graph& g)
{
size_t NB = (graph_tool::is_directed(g)) ? B * B : (B * (B + 1)) / 2;
return lbinom(NB + E - 1, E);
return lbinom_fast<false>(NB + E - 1, E);
}
} // namespace graph_tool
......
......@@ -135,7 +135,7 @@ public:
S += lgamma_fast(_N + 1);
for (auto nr : _total)
S -= lgamma_fast(nr + 1);
S += safelog(_N);
S += safelog_fast(_N);
return S;
}
......@@ -147,10 +147,10 @@ public:
size_t total = 0;
for (auto& k_c : _hist[r])
{
S -= xlogx(k_c.second);
S -= xlogx_fast(k_c.second);
total += k_c.second;
}
S += xlogx(total);
S += xlogx_fast(total);
}
return S;
}
......@@ -258,8 +258,8 @@ public:
if (dN != 0)
{
S_b += safelog(_N);
S_a += safelog(_N + dN);
S_b += safelog_fast(_N);
S_a += safelog_fast(_N + dN);
}
return S_a - S_b;
......@@ -305,8 +305,8 @@ public:
return (B * (B + 1)) / 2;
};
S_b += lbinom(get_x(actual_B) + _E - 1, _E);
S_a += lbinom(get_x(actual_B + dB) + _E - 1, _E);
S_b += lbinom_fast<false>(get_x(actual_B) + _E - 1, _E);
S_a += lbinom_fast<false>(get_x(actual_B + dB) + _E - 1, _E);
}
return S_a - S_b;
......@@ -379,7 +379,7 @@ public:
if (iter != _hist[s].end())
nd = iter->second;
assert(nd + delta >= 0);
return -xlogx(nd + delta);
return -xlogx_fast(nd + delta);
};
double S_b = 0, S_a = 0;
......@@ -393,8 +393,8 @@ public:
S_a += get_Sk(r, deg, diff * nk);
});
S_b += xlogx(nr);
S_a += xlogx(nr + dn);
S_b += xlogx_fast(nr);
S_a += xlogx_fast(nr + dn);
return S_a - S_b;
}
......@@ -406,8 +406,8 @@ public:
auto get_Se = [&](int dn, int dkin, int dkout)
{
double S = 0;
S += lbinom(_total[r] + dn + _ep[r] - 1 + dkout, _ep[r] + dkout);
S += lbinom(_total[r] + dn + _em[r] - 1 + dkin, _em[r] + dkin);
S += lbinom_fast(_total[r] + dn + _ep[r] - 1 + dkout, _ep[r] + dkout);
S += lbinom_fast(_total[r] + dn + _em[r] - 1 + dkin, _em[r] + dkin);
return S;
};
......
......@@ -849,8 +849,8 @@ public:
{
size_t N_B_E_D = _B_E_D + dB_E_D;
dS_dl -= -safelog(_B_E_D);
dS_dl += -safelog(N_B_E_D);
dS_dl -= -safelog_fast(_B_E_D);
dS_dl += -safelog_fast(N_B_E_D);
_dBdx[i] = _recdx[i] * dB_E_D + _dBdx[i] * N_B_E_D;
......
......@@ -253,9 +253,9 @@ struct overlap_partition_stats_t
size_t n_bv = _bhist.find(bv)->second;
S += xlogx(n_bv);
S += xlogx_fast(n_bv);
for (auto& dh : cdeg_hist)
S -= xlogx(dh.second);
S -= xlogx_fast(dh.second);
}
return S;
}
......
......@@ -35,7 +35,7 @@ void init_safelog(size_t x)
{
__safelog_cache.resize(x + 1);
for (size_t i = old_size; i < __safelog_cache.size(); ++i)
__safelog_cache[i] = safelog(double(i));
__safelog_cache[i] = safelog(i);
}
}
}
......@@ -55,7 +55,7 @@ void init_xlogx(size_t x)
{
__xlogx_cache.resize(x + 1);
for (size_t i = old_size; i < __xlogx_cache.size(); ++i)
__xlogx_cache[i] = i * safelog(i);
__xlogx_cache[i] = xlogx(i);
}
}
}
......
......@@ -38,38 +38,58 @@ extern vector<double> __lgamma_cache;
void init_safelog(size_t x);
template <class Type>
inline double safelog(Type x)
inline double safelog(auto x)
{
if (x == 0)
return 0;
return log(x);
}
inline double safelog(size_t x)
template <bool Init=true>
inline double safelog_fast(auto x)
{
if (x >= __safelog_cache.size())
init_safelog(x);
if (size_t(x) >= __safelog_cache.size())
{
if (Init)
init_safelog(x);
else
return safelog(x);
}
return __safelog_cache[x];
}
void init_xlogx(size_t x);
inline double xlogx(size_t x)
inline double xlogx(auto x)
{
//return x * safelog(x);
if (x >= __xlogx_cache.size())
init_xlogx(x);
return x * safelog(x);
}
template <bool Init=true>
inline double xlogx_fast(auto x)
{
if (size_t(x) >= __xlogx_cache.size())
{
if (Init)
init_xlogx(x);
else
return xlogx(x);
}
return __xlogx_cache[x];
}
void init_lgamma(size_t x);
inline double lgamma_fast(size_t x)
template <bool Init=true>
inline double lgamma_fast(auto x)
{
//return lgamma(x);
if (x >= __lgamma_cache.size())
init_lgamma(x);
if (size_t(x) >= __lgamma_cache.size())
{
if (Init)
init_lgamma(x);
else
return lgamma(x);
}
return __lgamma_cache[x];
}
......
......@@ -29,7 +29,7 @@ namespace graph_tool
{
using namespace boost;
inline double lbinom(double N, double k)
inline double lbinom(auto N, auto k)
{
if (N == 0 || k == 0 || k >= N)
return 0;
......@@ -38,14 +38,15 @@ inline double lbinom(double N, double k)
return ((lgamma(N + 1) - lgamma(k + 1)) - lgamma(N - k + 1));
}
inline double lbinom_fast(int N, int k)
template <bool Init=true>
inline double lbinom_fast(auto N, auto k)
{
if (N == 0 || k == 0 || k > N)
return 0;
return ((lgamma_fast(N + 1) - lgamma_fast(k + 1)) - lgamma_fast(N - k + 1));
return ((lgamma_fast<Init>(N + 1) - lgamma_fast<Init>(k + 1)) - lgamma_fast<Init>(N - k + 1));
}
inline double lbinom_careful(double N, double k)
inline double lbinom_careful(auto N, auto k)
{
if (N == 0 || k == 0 || k >= N)
return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment