Commit 2f3b46a9 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: fix bug in detailed balance enforcement

parent 2b0a9da7
......@@ -1210,17 +1210,12 @@ public:
void set_move(size_t, size_t) {}
void insert_delta(size_t r, size_t s, int delta, bool source,
void insert_delta(size_t t, size_t s, int delta,
size_t mrs = numeric_limits<size_t>::max())
{
auto& entry = _entries[_pos];
if (source)
entry = make_pair(s, r);
else
entry = make_pair(r, s);
if (!is_directed::apply<Graph>::type::value &&
entry.second < entry.first)
std::swap(entry.first, entry.second);
if (!is_directed::apply<Graph>::type::value && (t > s))
std::swap(t, s);
_entries[_pos] = make_pair(t, s);
_delta[_pos] = delta;
_mrs[_pos] = mrs;
++_pos;
......@@ -1228,9 +1223,14 @@ public:
int get_delta(size_t t, size_t s)
{
auto& entry = _entries[0];
if (entry.first == t && entry.second == s)
return _delta[0];
if (!is_directed::apply<Graph>::type::value && (t > s))
std::swap(t, s);
for (size_t i = 0; i < 2; ++i)
{
auto& entry = _entries[i];
if (entry.first == t && entry.second == s)
return _delta[i];
}
return 0;
}
......@@ -1238,7 +1238,7 @@ public:
const std::array<pair<size_t, size_t>,2>& get_entries() { return _entries; }
const std::array<int, 2>& get_delta() { return _delta; }
std::array<size_t, 2>& get_mrs() { return _mrs; }
const std::array<size_t, 2>& get_mrs() { return _mrs; }
private:
size_t _pos;
......
......@@ -715,31 +715,61 @@ public:
_rnr = make_pair(r, nr);
}
void insert_delta(size_t r, size_t s, int delta, bool source,
void insert_delta(size_t t, size_t s, int delta,
size_t mrs = numeric_limits<size_t>::max())
{
if (s == _rnr.first || s == _rnr.second)
insert_delta(t, s, delta, mrs,
typename is_directed::apply<Graph>::type());
}
void insert_delta(size_t t, size_t s, int delta, size_t mrs, std::true_type)
{
bool src = false;
if (t != _rnr.first && t != _rnr.second)
{
if ((!is_directed::apply<Graph>::type::value && s < r) || source)
std::swap(r, s);
if (source)
source = false;
std::swap(t, s);
src = true;
}
if (source && (s == r))
source = false;
assert(t == _rnr.first || t == _rnr.second);
auto& r_field = (source) ? _r_field_s : _r_field_t;
auto& nr_field = (source) ? _nr_field_s : _nr_field_t;
auto& r_field = (src) ? _r_field_s : _r_field_t;
auto& nr_field = (src) ? _nr_field_s : _nr_field_t;
vector<size_t>& field = (_rnr.first == r) ? r_field : nr_field;
auto& field = (_rnr.first == t) ? r_field : nr_field;
if (field[s] == _null)
{
field[s] = _entries.size();
if ((!is_directed::apply<Graph>::type::value && s < r) || source)
_entries.emplace_back(s, r);
if (src)
_entries.emplace_back(s, t);
else
_entries.emplace_back(r, s);
_entries.emplace_back(t, s);
_delta.push_back(delta);
_mrs.push_back(mrs);
}
else
{
_delta[field[s]] += delta;
}
}
void insert_delta(size_t t, size_t s, int delta, size_t mrs, std::false_type)
{
if (t > s)
std::swap(t, s);
if (t != _rnr.first && t != _rnr.second)
std::swap(t, s);
assert(t == _rnr.first || t == _rnr.second);
auto& r_field = _r_field_t;
auto& nr_field = _nr_field_t;
auto& field = (_rnr.first == t) ? r_field : nr_field;
if (field[s] == _null)
{
field[s] = _entries.size();
_entries.emplace_back(t, s);
_delta.push_back(delta);
_mrs.push_back(mrs);
}
......@@ -757,15 +787,18 @@ public:
return get_delta_target(t, s);
if (s == _rnr.first || s == _rnr.second)
return get_delta_source(t, s);
return 0;
}
else
{
if (t > s)
std::swap(t, s);
if (t != _rnr.first && t != _rnr.second)
std::swap(t, s);
if (t == _rnr.first || t == _rnr.second)
return get_delta_target(t, s);
if (s == _rnr.first || s == _rnr.second)
return get_delta_target(s, t);
return 0;
}
return 0;
}
int get_delta_target(size_t r, size_t s)
......@@ -807,7 +840,7 @@ public:
const vector<pair<size_t, size_t> >& get_entries() { return _entries; }
const vector<int>& get_delta() { return _delta; }
vector<size_t>& get_mrs() { return _mrs; }
const vector<size_t>& get_mrs() { return _mrs; }
private:
static constexpr size_t _null = numeric_limits<size_t>::max();
......@@ -897,7 +930,7 @@ void remove_entries(Vertex v, Vertex r, Vprop& b, Eprop& eweights, CEprop& mrs,
const auto& me = bedge[e];
m_entries.insert_delta(r, s, -ew, false, mrs[me]);
m_entries.insert_delta(r, s, -ew, mrs[me]);
if (u == v || is_loop(v))
{
......@@ -919,7 +952,7 @@ void remove_entries(Vertex v, Vertex r, Vprop& b, Eprop& eweights, CEprop& mrs,
const auto& me = bedge[e];
m_entries.insert_delta(r, s, -ew, true, mrs[me]);
m_entries.insert_delta(s, r, -ew, mrs[me]);
}
}
......@@ -945,7 +978,7 @@ void add_entries(Vertex v, Vertex nr, Vprop& b, Eprop& eweights, Graph& g,
if (!is_directed::apply<Graph>::type::value)
self_weight += ew;
}
m_entries.insert_delta(nr, s, +ew, false);
m_entries.insert_delta(nr, s, +ew);
}
if (self_weight > 0 && self_weight % 2 == 0)
......@@ -958,7 +991,7 @@ void add_entries(Vertex v, Vertex nr, Vprop& b, Eprop& eweights, Graph& g,
continue;
vertex_t s = b[u];
int ew = eweights[e];
m_entries.insert_delta(nr, s, +ew, true);
m_entries.insert_delta(s, nr, +ew);
}
}
......@@ -969,7 +1002,7 @@ double entries_dS(MEntries& m_entries, Eprop& mrs, EMat& emat, BGraph& bg)
{
const auto& entries = m_entries.get_entries();
const auto& delta = m_entries.get_delta();
auto& d_mrs = m_entries.get_mrs();
const auto& d_mrs = m_entries.get_mrs();
double dS = 0;
for (size_t i = 0; i < entries.size(); ++i)
......@@ -978,7 +1011,7 @@ double entries_dS(MEntries& m_entries, Eprop& mrs, EMat& emat, BGraph& bg)
auto er = entry.first;
auto es = entry.second;
int d = delta[i];
size_t& ers = d_mrs[i];
size_t ers = d_mrs[i];
if (ers == numeric_limits<size_t>::max())
ers = get_mrs(er, es, mrs, emat); // slower
assert(int(ers) + d >= 0);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment