// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2006-2018 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#include "graph_tool.hh"
#include "random.hh"

#include <boost/python.hpp>

#include "graph_blockmodel_util.hh"
#include "graph_blockmodel.hh"
#include "graph_blockmodel_multiflip_mcmc.hh"
#include "../loops/mcmc_loop.hh"

using namespace boost;
using namespace graph_tool;

// Generates the name `block_state`, which the functions below use as
// `block_state::dispatch(python_object, callback)`.  GEN_DISPATCH and
// BLOCK_STATE_params are defined elsewhere; presumably the callback receives
// the concrete BlockState instance wrapped by the Python object -- confirm
// against the macro definition.
GEN_DISPATCH(block_state, BlockState, BLOCK_STATE_params)

// Templated counterpart: generates `mcmc_block_state<State>`, which the
// functions below use as
// `mcmc_block_state<State>::make_dispatch(python_object, callback)`,
// where State is the concrete block state type received from
// `block_state::dispatch`.
template <class State>
GEN_DISPATCH(mcmc_block_state, MCMC<State>::template MCMCBlockState,
             MCMC_BLOCK_STATE_params(State))

// Runs a single multiflip MCMC sweep and returns its result to Python.
//
// omcmc_state  -- Python object handed to mcmc_block_state<...>::make_dispatch
//                 to obtain the MCMC state.
// oblock_state -- Python object handed to block_state::dispatch to obtain the
//                 underlying block state.
// rng          -- random number generator forwarded to mcmc_sweep().
//
// Returns a Python tuple holding the elements of the tuple returned by
// mcmc_sweep().  If the dispatch callbacks are never invoked, the returned
// object is left default-constructed.
python::object do_multiflip_mcmc_sweep(python::object omcmc_state,
                                       python::object oblock_state,
                                       rng_t& rng)
{
    python::object ret;
    auto dispatch = [&](auto& block_state)
    {
        // Concrete block state type; selects the matching MCMC state wrapper.
        using state_t =
            typename std::remove_reference<decltype(block_state)>::type;

        mcmc_block_state<state_t>::make_dispatch
           (omcmc_state,
            [&](auto& s)
            {
                auto ret_ = mcmc_sweep(s, rng);
                // Unpack the C++ tuple into a Python tuple.
                ret = tuple_apply([&](auto&... args)
                                  { return python::make_tuple(args...); },
                                  ret_);
            });
    };
    block_state::dispatch(oblock_state, dispatch);
    return ret;
}

// Type-erased interface for one MCMC sweep, so that sweeps over different
// concrete state types can be stored in a single container.
class MCMC_sweep_base
{
public:
    // Instances are owned through base-class pointers, so destruction must
    // be virtual.
    virtual ~MCMC_sweep_base() = default;

    // Runs one sweep using the given random number generator and returns
    // the result tuple of mcmc_sweep().
    virtual std::tuple<double, size_t, size_t> run(rng_t&) = 0;
};

// Concrete sweep bound to a specific MCMC state type.
//
// The state is stored by value (copied from the constructor argument), so
// the sweep object does not depend on the lifetime of the reference it was
// built from.
template <class State>
class MCMC_sweep : public MCMC_sweep_base
{
public:
    explicit MCMC_sweep(State& s) : _s(s) {}

    // Forwards to mcmc_sweep() on the stored state copy.
    std::tuple<double, size_t, size_t> run(rng_t& rng) override
    {
        return mcmc_sweep(_s, rng);
    }
private:
    State _s;
};

// Runs one multiflip MCMC sweep for each (MCMC state, block state) pair,
// executing the sweeps in an OpenMP parallel loop.
//
// omcmc_states  -- Python sequence; element i is handed to
//                  mcmc_block_state<...>::make_dispatch.
// oblock_states -- Python sequence indexed in lockstep with omcmc_states;
//                  element i is handed to block_state::dispatch.
// rng           -- generator used to initialise parallel_rng and passed to
//                  parallel_rng<rng_t>::get() inside the parallel loop.
//
// Returns a Python list with one tuple per state, in input order, each built
// from the result of the corresponding sweep.
python::object do_multiflip_mcmc_sweep_parallel(python::object omcmc_states,
                                                python::object oblock_states,
                                                rng_t& rng)
{
    std::vector<std::shared_ptr<MCMC_sweep_base>> sweeps;

    size_t N = python::len(omcmc_states);
    sweeps.reserve(N);

    // Build the sweep objects serially.  Each MCMC_sweep stores its own copy
    // of the dispatched state.
    for (size_t i = 0; i < N; ++i)
    {
        auto dispatch = [&](auto& block_state)
        {
            using state_t =
                typename std::remove_reference<decltype(block_state)>::type;

            mcmc_block_state<state_t>::make_dispatch
               (omcmc_states[i],
                [&](auto& s)
                {
                    using s_t =
                        typename std::remove_reference<decltype(s)>::type;
                    sweeps.push_back(std::make_shared<MCMC_sweep<s_t>>(s));
                });
        };
        block_state::dispatch(oblock_states[i], dispatch);
    }

    // NOTE(review): the loop below indexes sweeps[0..N-1], i.e. it assumes
    // every dispatch above appended exactly one sweep -- confirm that a
    // failed dispatch throws rather than returning silently.

    parallel_rng<rng_t>::init(rng);

    std::vector<std::tuple<double, size_t, size_t>> rets(N);

    // NOTE(review): an exception escaping run() inside the OpenMP region
    // cannot propagate to the caller -- confirm mcmc_sweep() does not throw.
    #pragma omp parallel for schedule(runtime)
    for (size_t i = 0; i < N; ++i)
    {
        // Presumably a per-thread generator -- confirm against parallel_rng.
        auto& rng_ = parallel_rng<rng_t>::get(rng);
        rets[i] = sweeps[i]->run(rng_);
    }

    // Convert the results to Python objects after the parallel loop.
    python::list orets;
    for (auto& ret : rets)
        orets.append(tuple_apply([&](auto&... args)
                                 { return python::make_tuple(args...); },
                                 ret));
    return orets;
}

namespace graph_tool
{
std::ostream& operator<<(std::ostream& os, move_t move)
{
    return os << static_cast<int>(move);
}
}

void export_blockmodel_multiflip_mcmc()
{
    using namespace boost::python;
    def("multiflip_mcmc_sweep", &do_multiflip_mcmc_sweep);
134
    def("multiflip_mcmc_sweep_parallel", &do_multiflip_mcmc_sweep_parallel);
135
}