Commit f8ffd9c1 authored by Tiago Peixoto's avatar Tiago Peixoto

inference: Fix issues with reconstruction from epidemic dynamics

parent 3512df30
......@@ -103,6 +103,19 @@ public:
_T.push_back(T);
}
reset_m(s);
_m_temp.resize(_s.size());
};
template <class State>
void reset_m(State& s)
{
for (auto v : vertices_range(s._u))
{
for (auto& m : _m)
m[v].clear();
}
auto xc = s._x.get_checked();
for (auto v : vertices_range(s._u))
{
......@@ -145,8 +158,7 @@ public:
}
}
_m_temp.resize(_s.size());
};
}
template <class State>
bool check_m(State& s, size_t v)
......@@ -597,9 +609,7 @@ public:
if (n == -1 ||
python::extract<double>(params["r"]).check())
{
_beta.resize(_s.size());
_r.resize(_s.size());
_log_p.resize(_s.size());
for (size_t n = 0; n < _s.size(); ++n)
set_params(params, n);
......@@ -635,11 +645,6 @@ public:
_r[n] = python::extract<double>(params["r"][n]);
}
double get_log_P(size_t m, double r, double beta)
{
return r + (1-r) * (1 - std::pow(1-beta, m));
}
double log_P(size_t v, size_t n, double m, int s, int ns)
{
if (s != State::S)
......@@ -667,15 +672,12 @@ public:
hmap_t::unchecked_t _r_v;
private:
std::vector<double> _beta;
std::vector<double> _r;
std::vector<typename amap_t::unchecked_t> _active;
bool _has_r_v;
bool _exposed;
int _E;
size_t _N;
std::vector<std::vector<std::pair<double, double>>> _log_p;
};
template <class T>
......
......@@ -101,7 +101,12 @@ void export_epidemics_state()
get_xedges_prob(state, edges, probs, ea,
epsilon);
})
.def("set_params", &state_t::set_params);
.def("set_params", &state_t::set_params)
.def("reset_m",
+[](state_t& state)
{
state._dstate.reset_m(state);
});
});
});
......
......@@ -1006,6 +1006,11 @@ class EpidemicsBlockState(DynamicsBlockStateBase):
state["global_beta"] = beta
self.__init__(**state, beta=beta)
def copy(self, **kwargs):
"""Return a copy of the state."""
return type(self)(**dict(self.__getstate__(),
**dict(kwargs, beta=kwargs.get("beta", None))))
def set_params(self, params):
r"""Sets the model parameters via the dictionary ``params``."""
self.params = dict(self.params, **params)
......@@ -1013,6 +1018,7 @@ class EpidemicsBlockState(DynamicsBlockStateBase):
beta = self.params["global_beta"]
if beta is not None:
self.x.fa = log1p(-beta)
self._state.reset_m()
def get_x(self):
"""Return latent edge transmission probabilities."""
......@@ -1040,10 +1046,10 @@ class EpidemicsBlockState(DynamicsBlockStateBase):
"""
return super(EpidemicBlockState, self).mcmc_sweep(r=r, p=p, pstep=p,
h=h, hstep=hstep,
xstep=xstep,
multiflip=multiflip)
return super(EpidemicsBlockState, self).mcmc_sweep(r=r, p=p, pstep=p,
h=h, hstep=hstep,
xstep=xstep,
multiflip=multiflip)
def _algo_sweep(self, algo, r=.5, p=.1, pstep=.1, h=.1, hstep=1,
xstep=.1, niter=1, **kwargs):
......@@ -1066,8 +1072,7 @@ class EpidemicsBlockState(DynamicsBlockStateBase):
ret = self._move_proposal("global_beta",
kwargs.get("beta", 1),
pstep, (0, 1),
(lambda beta: log1p(-beta),
lambda x: 1-exp(x)),
None,
kwargs.get("entropy_args", {}))
dS += ret[0]
nt += ret[1]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment