// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2006-2020 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or modify it under
// the terms of the GNU Lesser General Public License as published by the Free
// Software Foundation; either version 3 of the License, or (at your option) any
// later version.
//
// This program is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
// details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef GRAPH_BLOCKMODEL_MULTIFLIP_MCMC_HH
#define GRAPH_BLOCKMODEL_MULTIFLIP_MCMC_HH

#include "config.h"

#include <algorithm>
#include <array>
#include <cmath>
#include <iostream>
#include <limits>
#include <random>
#include <tuple>
#include <vector>

#include "graph_tool.hh"
#include "../support/graph_state.hh"
#include "graph_blockmodel_util.hh"
#include <boost/mpl/vector.hpp>

namespace graph_tool
{
using namespace boost;
using namespace std;

// Parameter list of the merge-split MCMC sweep state, consumed by
// GEN_STATE_BASE / GET_PARAMS_USING / GET_PARAMS_TYPEDEF inside struct MCMC.
// Each entry is (name, reference-qualifier, type, flag). The entry order
// presumably fixes the positional order of the constructor arguments, so do
// not reorder entries. Do not put comments between the continuation lines:
// backslash-newline splicing happens before comment removal, so a `//`
// comment there would swallow the next entry.
#define MCMC_BLOCK_STATE_params(State)                                         \
    ((__class__,&, mpl::vector<python::object>, 1))                            \
    ((state, &, State&, 0))                                                    \
    ((beta,, double, 0))                                                       \
    ((c,, double, 0))                                                          \
    ((d,, double, 0))                                                          \
    ((psingle,, double, 0))                                                    \
    ((psplit,, double, 0))                                                     \
    ((pmerge,, double, 0))                                                     \
    ((pmergesplit,, double, 0))                                                \
    ((nproposal, &, vector<size_t>&, 0))                                       \
    ((nacceptance, &, vector<size_t>&, 0))                                     \
    ((gibbs_sweeps,, size_t, 0))                                               \
    ((oentropy_args,, python::object, 0))                                      \
    ((verbose,, int, 0))                                                       \
    ((force_move,, bool, 0))                                                   \
    ((niter,, size_t, 0))

// Kind of proposal drawn by MCMCBlockState::move_proposal(). `null` marks a
// proposal that was aborted. The numeric value of a non-null kind is used as
// the index into the nproposal / nacceptance counters.
enum class move_t { single = 0, split, merge, mergesplit, null };

template <class State>
struct MCMC
{
    GEN_STATE_BASE(MCMCBlockStateBase, MCMC_BLOCK_STATE_params(State))

    template <class... Ts>
    class MCMCBlockState
        : public MCMCBlockStateBase<Ts...>
    {
    public:
        GET_PARAMS_USING(MCMCBlockStateBase<Ts...>,
                         MCMC_BLOCK_STATE_params(State))
        GET_PARAMS_TYPEDEF(Ts, MCMC_BLOCK_STATE_params(State))

        template <class... ATs,
                  typename std::enable_if_t<sizeof...(ATs) ==
                                            sizeof...(Ts)>* = nullptr>
        MCMCBlockState(ATs&&... as)
           : MCMCBlockStateBase<Ts...>(as...),
            _g(_state._g),
76
77
78
79
            _groups(num_vertices(_state._bg)),
            _vpos(get(vertex_index_t(), _state._g),
                  num_vertices(_state._g)),
            _rpos(get(vertex_index_t(), _state._bg),
80
81
                  num_vertices(_state._bg)),
            _bnext(get(vertex_index_t(), _state._g),
82
83
                   num_vertices(_state._g)),
            _btemp(get(vertex_index_t(), _state._g),
84
85
                   num_vertices(_state._g)),
            _entropy_args(python::extract<typename State::_entropy_args_t&>(_oentropy_args))
86
        {
87
            _state.init_mcmc(*this);
88
89
            for (auto v : vertices_range(_state._g))
            {
90
                if (_state.node_weight(v) == 0)
91
                    continue;
92
93
                auto r = _state._b[v];
                add_element(_groups[r], _vpos, v);
94
                _N += _state.node_weight(v);
95
                _vertices.push_back(v);
96
            }
97

98
            for (auto r : vertices_range(_state._bg))
99
            {
100
                if (_groups[r].empty())
101
                    continue;
102
                add_element(_rlist, _rpos, r);
103
            }
104
105
106
107
108
109
110

            std::vector<move_t> moves
                = {move_t::single, move_t::split, move_t::merge,
                   move_t::mergesplit};
            std::vector<double> probs
                = {_psingle, _psplit, _pmerge, _pmergesplit};
            _move_sampler = Sampler<move_t, mpl::false_>(moves, probs);
111
112
113
114
        }

        typename state_t::g_t& _g;

115
        std::vector<size_t> _vertices;
116
117
        std::vector<std::vector<size_t>> _groups;
        typename vprop_map_t<size_t>::type::unchecked_t _vpos;
118
        size_t _nmoves = 0;
119

120
121
122
123
124
125
126
127
128
129
130
131
        std::vector<std::vector<std::tuple<size_t, size_t>>> _bstack;

        Sampler<move_t, mpl::false_> _move_sampler;

        // Recursion terminator for the variadic _push_b_dispatch below.
        void _push_b_dispatch() {}

        // Appends (vertex, current group) for every vertex of `vs` to the top
        // frame of _bstack, lets the wrapped state save its own data via
        // push_state(vs), then recurses on the remaining vertex lists.
        template <class... Vs>
        void _push_b_dispatch(const std::vector<size_t>& vs, Vs&&... vvs)
        {
            auto& frame = _bstack.back();
            for (const auto u : vs)
                frame.emplace_back(u, _state._b[u]);
            _state.push_state(vs);
            _push_b_dispatch(std::forward<Vs>(vvs)...);
        }

        // Opens a new _bstack frame holding the current assignments of all
        // vertices in the given lists; undo with pop_b().
        template <class... Vs>
        void push_b(Vs&&... vvs)
        {
            _bstack.emplace_back();
            _push_b_dispatch(std::forward<Vs>(vvs)...);
        }

        // Moves every vertex recorded in the top frame back to its saved
        // group, drops the frame, and lets the wrapped state pop its data.
        void pop_b()
        {
            for (auto& [u, saved] : _bstack.back())
                move_vertex(u, saved);
            _bstack.pop_back();
            _state.pop_state();
        }

        std::vector<size_t> _rlist;
157
158
159
        std::vector<size_t> _rlist_split;
        typename vprop_map_t<size_t>::type::unchecked_t _rpos;

160
        std::vector<size_t> _vs;
161
        move_t _move;
162

163
164
        typename vprop_map_t<int>::type::unchecked_t _bnext;
        typename vprop_map_t<int>::type::unchecked_t _btemp;
165
        typename State::_entropy_args_t& _entropy_args;
166

167
        constexpr static size_t _null_move = 1;
168

169
170
        size_t _N = 0;

171
172
173
        double _dS;
        double _a;

174
        size_t node_state(size_t r)
175
        {
176
            return r;
177
178
        }

179
        constexpr bool skip_node(size_t)
180
        {
181
            return false;
182
183
        }

184
185
        template <bool sample_branch=true, class RNG, class VS = std::array<size_t,0>>
        size_t sample_new_group(size_t v, RNG& rng, VS&& except = VS())
186
        {
187
188
189
190
191
192
193
            _state.get_empty_block(v, except.size() >= _state._empty_blocks.size());
            size_t t;
            do
            {
                t = uniform_sample(_state._empty_blocks, rng);
            } while (!except.empty() &&
                     std::find(except.begin(), except.end(), t) != except.end());
194

195
            auto r = _state._b[v];
196
            _state._bclabel[t] = _state._bclabel[r];
197
            if (_state._coupled_state != nullptr)
198
199
200
            {
                if constexpr (sample_branch)
                {
201
202
203
204
205
                    do
                    {
                        _state._coupled_state->sample_branch(t, r, rng);
                    }
                    while(!_state.allow_move(r, t));
206
207
208
209
210
211
212
213
214
215
                }
                else
                {
                    auto& bh = _state._coupled_state->get_b();
                    bh[t] = bh[r];
                }
                auto& hpclabel = _state._coupled_state->get_pclabel();
                hpclabel[t] = _state._pclabel[v];
            }

216
217
218
219
220
221
222
223
            if (t >= _groups.size())
            {
                _groups.resize(t + 1);
                _rpos.resize(t + 1);
            }
            assert(_state._wr[t] == 0);
            return t;
        }
224

225
226
        void move_vertex(size_t v, size_t r)
        {
227
            size_t s = _state._b[v];
228
            _state.move_vertex(v, r);
229
230
            if (s == r)
                return;
231
232
            remove_element(_groups[s], _vpos, v);
            add_element(_groups[r], _vpos, v);
233
            _nmoves++;
234
        }
235

236
237
238
239
        template <class RNG>
        std::tuple<double, double>
        gibbs_sweep(std::vector<size_t>& vs, size_t r, size_t s,
                    double beta, RNG& rng)
240
        {
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
            double lp = 0, dS = 0;
            std::array<double,2> p = {0,0};
            std::shuffle(vs.begin(), vs.end(), rng);
            for (auto v : vs)
            {
                size_t bv = _state._b[v];
                size_t nbv = (bv == r) ? s : r;
                double ddS;
                if (_state.virtual_remove_size(v) > 0)
                    ddS = _state.virtual_move(v, bv, nbv, _entropy_args);
                else
                    ddS = std::numeric_limits<double>::infinity();

                if (!std::isinf(beta) && !std::isinf(ddS))
                {
                    double Z = log_sum(0., -ddS * beta);
                    p[0] = -ddS * beta - Z;
                    p[1] = -Z;
                }
                else
                {
                    if (ddS < 0)
                    {
                        p[0] = 0;
                        p[1] = -std::numeric_limits<double>::infinity();
                    }
                    else
                    {
                        p[0] = -std::numeric_limits<double>::infinity();;
                        p[1] = 0;
                    }
                }

                std::bernoulli_distribution sample(exp(p[0]));
                if (sample(rng))
                {
                    move_vertex(v, nbv);
                    lp += p[0];
                    dS += ddS;
                }
                else
                {
                    lp += p[1];
                }
            }
            return {dS, lp};
        }
288

289
        template <bool forward=true, class RNG>
290
        std::tuple<double, size_t, size_t>
291
        stage_split_random(std::vector<size_t>& vs, size_t r, size_t s, RNG& rng)
292
        {
293
294
            std::array<size_t, 2> rt = {null_group, null_group};
            double dS = 0;
295
296

            std::uniform_real_distribution<> unit(0, 1);
297
298
            double p0 = unit(rng);
            std::bernoulli_distribution sample(p0);
299

300
301
            std::shuffle(vs.begin(), vs.end(), rng);
            for (auto v : vs)
302
303
304
            {
                if (rt[0] == null_group)
                {
305
                    rt[0] = r;
306
307
                    dS += _state.virtual_move(v, _state._b[v], rt[0],
                                              _entropy_args);
308
                    move_vertex(v, rt[0]);
309
310
                    continue;
                }
311

312
313
                if (rt[1] == null_group)
                {
314
315
                    if constexpr (forward)
                        rt[1] = (s == null_group) ? sample_new_group(v, rng) : s;
316
317
                    else
                        rt[1] = s;
318
319
                    dS += _state.virtual_move(v, _state._b[v], rt[1],
                                              _entropy_args);
320
                    move_vertex(v, rt[1]);
321
322
                    continue;
                }
323

324
325
                if (sample(rng))
                {
326
327
328
                    dS += _state.virtual_move(v, _state._b[v], rt[0],
                                              _entropy_args);
                    move_vertex(v, rt[0]);
329
330
331
                }
                else
                {
332
                    dS += _state.virtual_move(v, _state._b[v], rt[1],
333
                                              _entropy_args);
334
                    move_vertex(v, rt[1]);
335
336
                }
            }
337
338
            return {dS, rt[0], rt[1]};
        }
339

340
341
342
343
344
345
346
347
        template <bool forward=true, class RNG>
        std::tuple<double, size_t, size_t>
        stage_split_scatter(std::vector<size_t>& vs, size_t r, size_t s, RNG& rng)
        {
            std::array<size_t, 2> rt = {null_group, null_group};
            std::array<double, 2> ps;
            double dS = 0;

348
            std::array<size_t, 2> except = {r, s};
349
350
            size_t t;
            if (_rlist.size() < (forward ? _N - 1 : _N))
351
                t = sample_new_group<false>(_groups[r].front(), rng, except);
352
353
354
            else
                t = r;

355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
            for (auto v : _groups[r])
            {
                dS += _state.virtual_move(v, _state._b[v], t,
                                          _entropy_args);
                move_vertex(v, t);
            }

            if constexpr (!forward)
            {
                for (auto v : _groups[s])
                {
                    dS += _state.virtual_move(v, _state._b[v], t,
                                              _entropy_args);
                    move_vertex(v, t);
                }
            }

            std::shuffle(vs.begin(), vs.end(), rng);
            for (auto v : vs)
            {
                if (rt[0] == null_group)
                {
                    rt[0] = r;
                    dS += _state.virtual_move(v, _state._b[v], rt[0],
                                              _entropy_args);
                    move_vertex(v, rt[0]);
                    continue;
                }
383

384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
                if (rt[1] == null_group)
                {
                    if constexpr (forward)
                        rt[1] = (s == null_group) ? sample_new_group(v, rng) : s;
                    else
                        rt[1] = s;
                    dS += _state.virtual_move(v, _state._b[v], rt[1],
                                              _entropy_args);
                    move_vertex(v, rt[1]);
                    continue;
                }

                ps[0] = _state.virtual_move(v, _state._b[v], rt[0],
                                            _entropy_args);
                ps[1] = _state.virtual_move(v, _state._b[v], rt[1],
                                            _entropy_args);;

                double Z = log_sum(ps[0], ps[1]);
                double p0 = ps[0] - Z;
                std::bernoulli_distribution sample(exp(p0));
                if (sample(rng))
                {
                    dS += ps[0];
                    move_vertex(v, rt[0]);
                }
                else
                {
                    dS += ps[1];
                    move_vertex(v, rt[1]);
                }
            }
            return {dS, rt[0], rt[1]};
        }

        // Initial split stage "coalesce". Steps:
        //  1. Make sure enough empty blocks exist for one per vertex of r
        //     (plus s in the reverse direction).
        //  2. Move every such vertex into its own freshly sampled empty group
        //     (never r or s), as long as _rlist.size() plus the number of
        //     vertices already placed stays below _N - 1 (forward) / _N
        //     (reverse). Otherwise the vertex goes to r (resp. s).
        //  3. `vs` is shuffled in place and walked in order. The first vertex
        //     goes to r. The second goes to s, or, in the forward direction
        //     with s == null_group, to a freshly sampled empty group.
        //  4. Every later vertex joins r or the second group with probability
        //     proportional to exp(-dS), consistent with gibbs_sweep().
        // Returns (accumulated dS, r, second group).
        template <bool forward=true, class RNG>
        std::tuple<double, size_t, size_t>
        stage_split_coalesce(std::vector<size_t>& vs, size_t r, size_t s, RNG& rng)
        {
            std::array<size_t, 2> rt = {null_group, null_group};
            double dS = 0;

            size_t pos = 0;
            std::array<size_t, 2> except = {r, s};

            size_t nB = _groups[r].size();
            if constexpr (!forward)
                nB += _groups[s].size();
            if (_state._empty_blocks.size() < nB)
                _state.add_block(nB - _state._empty_blocks.size());

            // Iterate over snapshots. move_vertex() removes each vertex from
            // its current group's list, and sample_new_group() may resize
            // _groups. Ranging over the live _groups[r] / _groups[s] would
            // therefore invalidate the iteration, or even the reference to the
            // list itself.
            std::vector<size_t> r_members = _groups[r];
            for (auto v : r_members)
            {
                size_t t;
                if (_rlist.size() + pos < (forward ? _N - 1 : _N))
                    t = sample_new_group<false>(v, rng, except);
                else
                    t = r;
                dS += _state.virtual_move(v, _state._b[v], t,
                                          _entropy_args);
                move_vertex(v, t);
                ++pos;
            }

            if constexpr (!forward)
            {
                std::vector<size_t> s_members = _groups[s];
                for (auto v : s_members)
                {
                    size_t t;
                    if (_rlist.size() + pos < _N)
                        t = sample_new_group<false>(v, rng, except);
                    else
                        t = s;
                    dS += _state.virtual_move(v, _state._b[v], t,
                                              _entropy_args);
                    move_vertex(v, t);
                    ++pos;
                }
            }

            std::shuffle(vs.begin(), vs.end(), rng);
            for (auto v : vs)
            {
                if (rt[0] == null_group)
                {
                    rt[0] = r;
                    dS += _state.virtual_move(v, _state._b[v], rt[0],
                                              _entropy_args);
                    move_vertex(v, rt[0]);
                    continue;
                }

                if (rt[1] == null_group)
                {
                    if constexpr (forward)
                        rt[1] = (s == null_group) ? sample_new_group(v, rng) : s;
                    else
                        rt[1] = s;
                    dS += _state.virtual_move(v, _state._b[v], rt[1],
                                              _entropy_args);
                    move_vertex(v, rt[1]);
                    continue;
                }

                double dS0 = _state.virtual_move(v, _state._b[v], rt[0],
                                                 _entropy_args);
                double dS1 = _state.virtual_move(v, _state._b[v], rt[1],
                                                 _entropy_args);

                // P(rt[0]) = exp(-dS0) / (exp(-dS0) + exp(-dS1))
                double Z = log_sum(-dS0, -dS1);
                std::bernoulli_distribution sample(exp(-dS0 - Z));
                if (sample(rng))
                {
                    dS += dS0;
                    move_vertex(v, rt[0]);
                }
                else
                {
                    dS += dS1;
                    move_vertex(v, rt[1]);
                }
            }
            return {dS, rt[0], rt[1]};
        }

        template <class RNG, bool forward=true>
511
        std::tuple<size_t, double, double> split(size_t r, size_t s, RNG& rng)
512
513
        {
            auto vs = _groups[r];
514

515
516
517
            if constexpr (!forward)
                vs.insert(vs.end(), _groups[s].begin(), _groups[s].end());

518
519
            double dS = 0;
            std::array<size_t, 2> rt = {null_group, null_group};
520

521
522
523
524
            std::uniform_int_distribution<int> stage_sample(0,2);
            switch (stage_sample(rng))
            {
            case 0:
525
                std::tie(dS, rt[0], rt[1]) = stage_split_random<forward>(vs, r, s, rng);
526
527
                break;
            case 1:
528
                std::tie(dS, rt[0], rt[1]) = stage_split_scatter<forward>(vs, r, s, rng);
529
530
531
532
533
534
535
                break;
            case 2:
                std::tie(dS, rt[0], rt[1]) = stage_split_coalesce<forward>(vs, r, s, rng);
                break;
            default:
                break;
            }
536
537
538
539
540
541
542

            for (size_t i = 0; i < _gibbs_sweeps - 1; ++i)
            {
                auto ret = gibbs_sweep(vs, rt[0], rt[1],
                                       (i < _gibbs_sweeps / 2) ? 1 : _beta,
                                       rng);
                dS += get<0>(ret);
543
            }
544

545
546
547
548
549
550
551
552
553
            double lp = 0;
            if constexpr (forward)
            {
                auto ret = gibbs_sweep(vs, rt[0], rt[1], _beta, rng);
                dS += get<0>(ret);
                lp = get<1>(ret);
            }

            return {rt[1], dS, lp};
554
555
556
        }

        template <class RNG>
557
        double split_prob(size_t r, size_t s, RNG& rng)
558
        {
559
560
            auto vs = _groups[r];
            vs.insert(vs.end(), _groups[s].begin(), _groups[s].end());
561

562
563
564
565
566
567
568
569
            push_b(vs);

            for (auto v : vs)
                _btemp[v] = _state._b[v];

            split<RNG, false>(r, s, rng);

            std::shuffle(vs.begin(), vs.end(), rng);
570

571
572
            double lp = 0;
            for (auto v : vs)
573
            {
574
575
576
577
578
579
580
                size_t bv = _state._b[v];
                size_t nbv = (bv == r) ? s : r;
                double ddS;
                if (_state.virtual_remove_size(v) > 0)
                    ddS = _state.virtual_move(v, bv, nbv, _entropy_args);
                else
                    ddS = std::numeric_limits<double>::infinity();
581

582
583
                size_t tbv = _btemp[v];

584
                if (!std::isinf(ddS))
585
                {
586
                    ddS *= _beta;
587
                    double Z = log_sum(0., -ddS);
588

589
590
591
592
593
594
595
596
597
                    if (tbv == nbv)
                    {
                        move_vertex(v, nbv);
                        lp += -ddS - Z;
                    }
                    else
                    {
                        lp += -Z;
                    }
598
599
600
                }
                else
                {
601
602
603
604
605
                    if (tbv == nbv)
                    {
                        lp = -std::numeric_limits<double>::infinity();
                        break;
                    }
606
                }
607
            }
608

609
            pop_b();
610

611
            return lp;
612
613
        }

614
615
        // A merge of r into s is allowed exactly when the wrapped state allows
        // moving from r to s.
        bool allow_merge(size_t r, size_t s)
        {
            return _state.allow_move(r, s);
        }

        // Moves every vertex of group r into group s and returns the
        // accumulated entropy difference.
        double merge(size_t r, size_t s)
        {
            // copy: move_vertex() shrinks _groups[r] while we iterate
            const auto members = _groups[r];

            double dS = 0;
            for (auto v : members)
            {
                dS += _state.virtual_move(v, _state._b[v], s, _entropy_args);
                move_vertex(v, s);
            }
            return dS;
        }

        template <class RNG>
        size_t sample_move(size_t r, RNG& rng)
        {
638
            size_t v = uniform_sample(_groups[r], rng);
639
640
            auto s = r;
            while (s == r)
641
                s = _state.sample_block(v, _c, 0, rng);
642
            return s;
643
        }
644

645
646
        double get_move_prob(size_t r, size_t s)
        {
647
648
            double prs = 0;
            double prr = 0;
649
650
651
652
653
654
655
656
            for (auto v : _groups[r])
            {
                prs += _state.get_move_prob(v, r, s, _c, 0, false);
                prr += _state.get_move_prob(v, r, r, _c, 0, false);
            }
            prs /= _groups[r].size();
            prr /= _groups[r].size();
            return prs/(1-prr);
657
658
        }

659
        double merge_prob(size_t r, size_t s)
660
        {
661
662
663
664
665
666
667
668
669
            return log(get_move_prob(r, s));
        }

        template <class RNG>
        std::tuple<size_t, double, double, double>
        sample_merge(size_t r, RNG& rng)
        {
            size_t s = sample_move(r, rng);

670
            if (s == r || !allow_merge(r, s))
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
                return {null_group, 0., 0., 0.};

            double pf = 0, pb = 0;
            if (!std::isinf(_beta))
            {
                pf = merge_prob(r, s);
                pb = split_prob(s, r, rng);
            }

            if (_verbose)
                cout << "merge " << _groups[r].size() << " " << _groups[s].size();

            double dS = merge(r, s);

            if (_verbose)
                cout << " " << dS << " " << pf << "  " << pb << endl;

            return {s, dS, pf, pb};
        }

        template <class RNG>
        std::tuple<size_t, double, double, double>
        sample_split(size_t r, size_t s, RNG& rng)
        {
            double dS, pf, pb=0;
            std::tie(s, dS, pf) = split(r, s, rng);
            if (!std::isinf(_beta))
                pb = merge_prob(s, r);

            if (_verbose)
                cout << "split " << _groups[r].size() << " " << _groups[s].size()
                     << " " << dS << " " << pf << " " << pb << endl;

            return {s, dS, pf, pb};
705
706
707
        }

        template <class RNG>
708
709
        std::tuple<size_t, size_t>
        move_proposal(size_t, RNG& rng)
710
711
        {
            double pf = 0, pb = 0;
712
            _dS = _a = 0;
713
714
            _vs.clear();
            _nmoves = 0;
715
            _state.clear_next_state();
716

717
            auto move = _move_sampler.sample(rng);
718
719

            switch (move)
720
            {
721
            case move_t::single:
722
                {
723
724
                    auto v = uniform_sample(_vertices, rng);
                    size_t r = _state._b[v];
725
726
                    auto s = _state.sample_block(v, _c, _d, rng);
                    if (s >= _groups.size())
727
                    {
728
729
730
731
732
733
734
735
                        _groups.resize(s + 1);
                        _rpos.resize(s + 1);
                    }
                    if (r == s || !_state.allow_move(r, s) ||
                        (_d == 0 && _groups[r].size() == 1 && !std::isinf(_beta)))
                    {
                        move = move_t::null;
                        break;
736
                    }
737
738
                    _dS = _state.virtual_move(v, r, s, _entropy_args);
                    if (!std::isinf(_beta))
739
                    {
740
741
                        pf = log(_state.get_move_prob(v, r, s, _c, _d, false));
                        pb = log(_state.get_move_prob(v, s, r, _c, _d, true));
742
                    }
743
744
745
746
                    _vs.clear();
                    _vs.push_back(v);
                    _bnext[v] = s;
                    _nmoves++;
747
                }
748
749
750
                break;

            case move_t::split:
751
                {
752
753
                    auto r = uniform_sample(_rlist, rng);

754
                    if (_groups[r].size() < 2)
755
756
757
758
                    {
                        move = move_t::null;
                        break;
                    }
759

760
761
                    _state._egroups_update = false;

762
763
764
765
766
767
768
                    _vs = _groups[r];
                    push_b(_vs);

                    size_t s;
                    std::tie(s, _dS, pf) = split(r, null_group, rng);

                    if (!std::isinf(_beta))
769
                    {
770
                        pf += log(_psplit);
771
772
773
774
775
776
                        pf += -safelog_fast(_rlist.size());

                        pb = merge_prob(s, r);
                        pb += -safelog_fast(_rlist.size()+1);

                        pb += log(_pmerge);
777
                    }
778

779
780
781
782
                    if (_verbose)
                        cout << "split proposal: " << _groups[r].size() << " "
                             << _groups[s].size() << " " << _dS << " " << pb - pf
                             << " " << -_dS + pb - pf << endl;
783

784
                    for (auto v : _vs)
785
                    {
786
                        _bnext[v] = _state._b[v];
787
788
                        _state.store_next_state(v);
                    }
789
                    pop_b();
790
791

                    _state._egroups_update = true;
792
793
                }
                break;
794

795
796
797
            case move_t::merge:
                {
                    if (_rlist.size() == 1)
798
799
800
801
802
                    {
                        move = move_t::null;
                        break;
                    }
                    auto r = uniform_sample(_rlist, rng);
803
804
                    auto s = sample_move(r, rng);
                    if (!allow_merge(r, s))
805
806
807
808
                    {
                        move = move_t::null;
                        break;
                    }
809

810
811
                    _state._egroups_update = false;

812
813
                    if (!std::isinf(_beta))
                    {
814
815
816
817
818
819
820
                        pf += log(_pmerge);
                        pf += -safelog_fast(_rlist.size());
                        pf += merge_prob(r, s);

                        pb = -safelog_fast(_rlist.size()-1);
                        pb += split_prob(s, r, rng);
                        pb += log(_psplit);
821
822
                    }

823
824
825
826
827
828
                    _vs = _groups[r];
                    push_b(_vs);

                    _dS = merge(r, s);

                    for (auto v : _vs)
829
                    {
830
                        _bnext[v] = _state._b[v];
831
832
                        _state.store_next_state(v);
                    }
833
834
                    pop_b();

835
836
                    _state._egroups_update = true;

837
                    if (_verbose)
838
839
                        cout << "merge proposal: " <<  _groups[r].size() << " "
                             << _groups[s].size() << " " << _dS << " " << pb - pf
840
                             << " " << -_dS + pb - pf << endl;
841
                }
842
843
844
                break;

            case move_t::mergesplit:
845
                {
846
                    if (_rlist.size() == 1)
847
848
849
850
851
852
                    {
                        move = move_t::null;
                        break;
                    }

                    size_t r = uniform_sample(_rlist, rng);
853

854
855
                    _state._egroups_update = false;

856
857
858
859
860
861
                    push_b(_groups[r]);

                    auto ret = sample_merge(r, rng);
                    size_t s = get<0>(ret);

                    if (s == null_group)
862
863
864
                    {
                        while (!_bstack.empty())
                            pop_b();
865
                        _state._egroups_update = true;
866
867
                        move = move_t::null;
                        break;
868
                    }
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886

                    _dS += get<1>(ret);
                    pf += get<2>(ret);
                    pb += get<3>(ret);

                    push_b(_groups[s]);

                    ret = sample_split(s, r, rng);
                    _dS += get<1>(ret);
                    pf += get<2>(ret);
                    pb += get<3>(ret);

                    for (auto& vs : _bstack)
                        for (auto& vb : vs)
                        {
                            auto v = get<0>(vb);
                            _vs.push_back(v);
                            _bnext[v] = _state._b[v];
887
                            _state.store_next_state(v);
888
889
890
891
892
                        }

                    while (!_bstack.empty())
                        pop_b();

893
894
                    _state._egroups_update = true;

895
896
897
                    if (_verbose)
                        cout << "mergesplit proposal: " << _dS << " " << pb - pf
                             << " " << -_dS + pb - pf << endl;
898
                }
899
900
901
                break;

            default:
902
                move = move_t::null;
903
                break;
904
            }
905

906
907
908
909
910
            if (move == move_t::null)
                return {_null_move, std::max(1, _nmoves)};

            _move = move;

911
            _a = pb - pf;
912

913
914
915
916
917
918
            if (size_t(move) >= _nproposal.size())
            {
                _nproposal.resize(size_t(move) + 1);
                _nacceptance.resize(size_t(move) + 1);
            }
            _nproposal[size_t(move)]++;
919

920
921
922
923
924
925
            if (_force_move)
            {
                _nmoves = std::numeric_limits<size_t>::max();
                _a = _dS * _beta + 1;
            }

926
            return {0, _nmoves};
927
928
929
        }

        std::tuple<double, double>
930
        virtual_move_dS(size_t, size_t)
931
932
        {
            return {_dS, _a};
933
934
        }

935
        void perform_move(size_t, size_t)
936
        {
937
            for (auto v : _vs)
938
            {
939
940
                size_t r = _state._b[v];
                size_t s = _bnext[v];
941
942
943
                if (r == s)
                    continue;

944
945
                if (_groups[s].empty())
                    add_element(_rlist, _rpos, s);
946

947
                move_vertex(v, s);
948

949
                if (_groups[r].empty())
950
                    remove_element(_rlist, _rpos, r);
951
            }
952

953
            _nacceptance[size_t(_move)]++;
954
        }
955

956
        constexpr bool is_deterministic()
957
        {
958
            return true;
959
960
        }

961
        constexpr bool is_sequential()
962
        {
963
            return false;
964
965
        }

966
        std::array<size_t, 1> _vlist = {0};
967
968
        auto& get_vlist()
        {
969
            return _vlist;
970
971
972
973
974
        }

        size_t get_N()
        {
            return _N;
975
976
977
978
979
980
981
982
983
        }

        double get_beta()
        {
            return _beta;
        }

        size_t get_niter()
        {
984
            return _niter;
985
986
        }

987
        constexpr void step(size_t, size_t)
988
989
        {
        }
990
991
992
993
994
995
    };
};

} // graph_tool namespace

#endif // GRAPH_BLOCKMODEL_MULTIFLIP_MCMC_HH