Commit 8918329f authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: preserve pclabel in mcmc

parent 1d780746
......@@ -1719,21 +1719,29 @@ public:
void sample_branch(size_t v, size_t u, rng_t& rng)
{
size_t s;
auto r = _b[u];
std::bernoulli_distribution new_r(1./_candidate_blocks.size());
if (_candidate_blocks.size() <= num_vertices(_g) && new_r(rng))
{
get_empty_block(v);
s = uniform_sample(_empty_blocks, rng);
auto r = _b[u];
if (_coupled_state != nullptr)
{
_coupled_state->sample_branch(s, r, rng);
auto& hpclabel = _coupled_state->get_pclabel();
hpclabel[s] = _pclabel[u];
}
_bclabel[s] = _bclabel[r];
}
else
{
s = uniform_sample(_candidate_blocks.begin() + 1,
_candidate_blocks.end(), rng);
do
{
s = uniform_sample(_candidate_blocks.begin() + 1,
_candidate_blocks.end(), rng);
}
while(_bclabel[s] != _bclabel[r]);
}
_b[v] = s;
}
......@@ -1751,7 +1759,11 @@ public:
auto s = uniform_sample(_empty_blocks, rng);
auto r = _b[v];
if (_coupled_state != nullptr)
{
_coupled_state->sample_branch(s, r, rng);
auto& hpclabel = _coupled_state->get_pclabel();
hpclabel[s] = _pclabel[v];
}
_bclabel[s] = _bclabel[r];
return s;
}
......@@ -2419,7 +2431,13 @@ public:
partition_stats_t& get_partition_stats(size_t v)
{
return _partition_stats[_pclabel[v]];
size_t r = _pclabel[v];
if (r >= _partition_stats.size())
{
disable_partition_stats();
enable_partition_stats();
}
return _partition_stats[r];
}
void init_mcmc(double c, double dl)
......@@ -2490,6 +2508,11 @@ public:
return _b;
}
vprop_map_t<int32_t>::type::unchecked_t& get_pclabel()
{
return _pclabel;
}
bool check_edge_counts(bool emat=true)
{
gt_hash_map<std::pair<size_t, size_t>, size_t> mrs;
......
......@@ -69,6 +69,7 @@ public:
virtual double get_delta_partition_dl(size_t v, size_t r, size_t nr,
const entropy_args_t& ea) = 0;
virtual vprop_map_t<int32_t>::type::unchecked_t& get_b() = 0;
virtual vprop_map_t<int32_t>::type::unchecked_t& get_pclabel() = 0;
virtual bool check_edge_counts(bool emat=true) = 0;
virtual bool allow_move(size_t r, size_t nr) = 0;
};
......
......@@ -1065,6 +1065,11 @@ struct Layers
return BaseState::_b;
}
vprop_map_t<int32_t>::type::unchecked_t& get_pclabel()
{
return BaseState::_pclabel;
}
void sync_emat()
{
BaseState::sync_emat();
......
......@@ -1185,6 +1185,11 @@ public:
return _b;
}
vprop_map_t<int32_t>::type::unchecked_t& get_pclabel()
{
return _pclabel;
}
void init_mcmc(double c, double dl)
{
if (!std::isinf(c))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment