Commit 30e6b8a0 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

Allow for multiple simultaneous vertex movements in BlockState.move_vertex()

parent 2dfe4f33
......@@ -225,6 +225,8 @@ void export_blockmodel_state()
&state_t::remove_vertices;
void (state_t::*add_vertices)(python::object, python::object) =
&state_t::add_vertices;
void (state_t::*move_vertices)(python::object, python::object) =
&state_t::move_vertices;
double (state_t::*virtual_move)(size_t, size_t, bool, bool, bool,
bool, bool) =
&state_t::virtual_move;
......@@ -246,6 +248,7 @@ void export_blockmodel_state()
.def("remove_vertices", remove_vertices)
.def("add_vertices", add_vertices)
.def("move_vertex", &state_t::move_vertex)
.def("move_vertices", move_vertices)
.def("set_partition", set_partition)
.def("virtual_move", virtual_move)
.def("merge_vertices", merge_vertices)
......
......@@ -226,9 +226,7 @@ public:
void remove_vertices(python::object ovs)
{
vector<size_t> vs;
for (int i = 0; i < python::len(ovs); ++i)
vs.push_back(python::extract<size_t>(ovs[i]));
multi_array_ref<uint64_t, 1> vs = get_array<uint64_t, 1>(ovs);
remove_vertices(vs);
}
......@@ -400,13 +398,10 @@ public:
void add_vertices(python::object ovs, python::object ors)
{
vector<size_t> vs;
vector<size_t> rs;
for (int i = 0; i < python::len(ovs); ++i)
{
vs.push_back(python::extract<size_t>(ovs[i]));
rs.push_back(python::extract<size_t>(ors[i]));
}
multi_array_ref<uint64_t, 1> vs = get_array<uint64_t, 1>(ovs);
multi_array_ref<uint64_t, 1> rs = get_array<uint64_t, 1>(ors);
if (vs.size() != rs.size())
throw ValueException("vertex and group lists do not have the same size");
add_vertices(vs, rs);
}
......@@ -422,6 +417,22 @@ public:
add_vertex(v, nr);
}
template <class Vec>
void move_vertices(Vec& v, Vec& nr)
{
for (size_t i = 0; i < std::min(v.size(), nr.size()); ++i)
move_vertex(v[i], nr[i]);
}
void move_vertices(python::object ovs, python::object ors)
{
multi_array_ref<uint64_t, 1> vs = get_array<uint64_t, 1>(ovs);
multi_array_ref<uint64_t, 1> rs = get_array<uint64_t, 1>(ors);
if (vs.size() != rs.size())
throw ValueException("vertex and group lists do not have the same size");
move_vertices(vs, rs);
}
template <class VMap>
void set_partition(VMap&& b)
{
......
......@@ -452,15 +452,17 @@ void export_layered_blockmodel_state()
= &state_t::get_move_prob;
void (state_t::*merge_vertices)(size_t, size_t)
= &state_t::merge_vertices;
void (state_t::*set_partition)(boost::any&)
= &state_t::set_partition;
void (state_t::*move_vertices)(python::object, python::object) =
&state_t::move_vertices;
class_<state_t> c(name_demangle(typeid(state_t).name()).c_str(),
no_init);
c.def("remove_vertex", &state_t::remove_vertex)
.def("add_vertex", &state_t::add_vertex)
.def("move_vertex", &state_t::move_vertex)
.def("move_vertices", move_vertices)
.def("set_partition", set_partition)
.def("virtual_move", virtual_move)
.def("merge_vertices", merge_vertices)
......
......@@ -215,6 +215,22 @@ struct Layers
}
}
template <class Vec>
void move_vertices(Vec& v, Vec& nr)
{
for (size_t i = 0; i < std::min(v.size(), nr.size()); ++i)
move_vertex(v[i], nr[i]);
}
void move_vertices(python::object ovs, python::object ors)
{
multi_array_ref<uint64_t, 1> vs = get_array<uint64_t, 1>(ovs);
multi_array_ref<uint64_t, 1> rs = get_array<uint64_t, 1>(ors);
if (vs.size() != rs.size())
throw ValueException("vertex and group lists do not have the same size");
move_vertices(vs, rs);
}
void remove_vertex(size_t v)
{
size_t r = _b[v];
......
......@@ -81,12 +81,15 @@ void export_layered_overlap_blockmodel_state()
double (state_t::*get_move_prob)(size_t, size_t, size_t, double,
bool)
= &state_t::get_move_prob;
void (state_t::*move_vertices)(python::object, python::object) =
&state_t::move_vertices;
class_<state_t> c(name_demangle(typeid(state_t).name()).c_str(),
no_init);
c.def("remove_vertex", &state_t::remove_vertex)
.def("add_vertex", &state_t::add_vertex)
.def("move_vertex", &state_t::move_vertex)
.def("move_vertices", move_vertices)
.def("virtual_move", virtual_move)
.def("sample_block", sample_block)
.def("entropy", &state_t::entropy)
......
......@@ -159,15 +159,17 @@ void export_overlap_blockmodel_state()
double (state_t::*get_move_prob)(size_t, size_t, size_t, double,
bool)
= &state_t::get_move_prob;
void (state_t::*set_partition)(boost::any&)
= &state_t::set_partition;
void (state_t::*move_vertices)(python::object, python::object) =
&state_t::move_vertices;
class_<state_t> c(name_demangle(typeid(state_t).name()).c_str(),
no_init);
c.def("remove_vertex", &state_t::remove_vertex)
.def("add_vertex", &state_t::add_vertex)
.def("move_vertex", &state_t::move_vertex)
.def("move_vertices", move_vertices)
.def("set_partition", set_partition)
.def("virtual_move", virtual_move)
.def("sample_block", sample_block)
......
......@@ -217,6 +217,22 @@ public:
add_vertex(v, nr);
}
template <class Vec>
void move_vertices(Vec& v, Vec& nr)
{
for (size_t i = 0; i < std::min(v.size(), nr.size()); ++i)
move_vertex(v[i], nr[i]);
}
void move_vertices(python::object ovs, python::object ors)
{
multi_array_ref<uint64_t, 1> vs = get_array<uint64_t, 1>(ovs);
multi_array_ref<uint64_t, 1> rs = get_array<uint64_t, 1>(ors);
if (vs.size() != rs.size())
throw ValueException("vertex and group lists do not have the same size");
move_vertices(vs, rs);
}
template <class VMap>
void set_partition(VMap&& b)
{
......
......@@ -605,8 +605,16 @@ class BlockState(object):
edges_dl, partition_dl)
def move_vertex(self, v, s):
r"""Move vertex ``v`` to block ``s``."""
self._state.move_vertex(int(v), s)
r"""Move vertex ``v`` to block ``s``.
This optionally accepts a list of vertices and blocks to move
simultaneously.
"""
if not isinstance(v, collections.Iterable):
self._state.move_vertex(int(v), s)
else:
self._state.move_vertices(numpy.asarray(v, dtype="uint64"),
numpy.asarray(s, dtype="uint64"))
def remove_vertex(self, v):
r"""Remove vertex ``v`` from its current group.
......@@ -620,7 +628,7 @@ class BlockState(object):
twice.
"""
if isinstance(v, collections.Iterable):
self._state.remove_vertices(list(v))
self._state.remove_vertices(numpy.asarray(v, dtype="uint64"))
else:
self._state.remove_vertex(int(v))
......@@ -635,7 +643,8 @@ class BlockState(object):
added twice to the same group.
"""
if isinstance(v, collections.Iterable):
self._state.add_vertices(list(v), list(r))
self._state.add_vertices(numpy.asarray(v, dtype="uint64"),
numpy.asarray(r, dtype="uint64"))
else:
self._state.add_vertex(int(v), r)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment