// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2006-2017 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef GRAPH_BLOCKMODEL_OVERLAP_PARTITION_HH
#define GRAPH_BLOCKMODEL_OVERLAP_PARTITION_HH

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <iterator>
#include <limits>
#include <set>
#include <tuple>
#include <utility>
#include <vector>

#include "graph_blockmodel_overlap_util.hh"
#include "boost/container/small_vector.hpp"

namespace std
{

// Hash functor for boost::container::small_vector, so that such vectors can
// be used as keys of hash-based containers (see cdeg_hist_t below).
//
// The elements are folded, in sequence order, into a zero-initialised seed
// through _hash_combine() (declared elsewhere in the project); an empty
// vector therefore hashes to 0.
template <class Value, size_t N>
struct hash<boost::container::small_vector<Value, N>>
{
    size_t operator()(const boost::container::small_vector<Value, N>& v) const
    {
        std::size_t seed = 0;
        for (const auto& x : v)
            _hash_combine(seed, x);
        return seed;
    }
};

} // namespace std

// "Empty" sentinel key for hash maps keyed by a small_vector: a one-element
// vector holding the element type's own empty key.
// NOTE(review): the primary empty_key template is declared elsewhere;
// presumably it belongs to the gt_hash_map wrapper -- confirm.
template <class Value, size_t N>
struct empty_key<boost::container::small_vector<Value, N>>
{
    static boost::container::small_vector<Value, N> get()
    {
        boost::container::small_vector<Value, N> key(1);
        key[0] = empty_key<Value>::get();
        return key;
    }
};

template <class Value, size_t N>
struct deleted_key<boost::container::small_vector<Value, N>>
{
    static boost::container::small_vector<Value, N> get()
    {
        boost::container::small_vector<Value, N> key(1);
        key[0] = deleted_key<Value>::get();
        return key;
    }
};


namespace graph_tool
{

//=============================
// Partition Description length
//=============================

struct overlap_partition_stats_t
{
    typedef std::tuple<int, int> deg_t;
74
    typedef boost::container::small_vector<deg_t, 64> cdeg_t;
75

76
    typedef boost::container::small_vector<int, 64> bv_t;
77

78
79
    typedef gt_hash_map<bv_t, size_t> bhist_t;
    typedef gt_hash_map<cdeg_t, size_t, std::hash<cdeg_t>> cdeg_hist_t;
80

81
    typedef gt_hash_map<bv_t, cdeg_hist_t> deg_hist_t;
82

83
    typedef gt_hash_map<bv_t, vector<size_t>> ebhist_t;
84

85
    typedef gt_hash_map<int, int> dmap_t;
86

87
88
89
90
    template <class Graph, class Vprop, class Eprop, class Vlist>
    overlap_partition_stats_t(Graph& g, Vprop& b, Vlist& vlist, size_t E,
                              size_t B, Eprop& eweight, overlap_stats_t& ostats,
                              std::vector<size_t>& bmap,
91
92
93
                              std::vector<size_t>& vmap,
                              bool allow_empty)
        : _overlap_stats(ostats), _bmap(bmap), _vmap(vmap),
94
95
96
          _allow_empty(allow_empty),
          _directed(is_directed::apply<Graph>::type::value)

97
    {
98
99
100
101
102
103
        _D = 0;
        _N = vlist.size();
        _E = E;
        _total_B = B;
        _dhist.resize(1);

104
        dmap_t in_hist, out_hist;
105
        for (size_t v : vlist)
106
        {
107
108
            auto nv = get_v(v);

109
110
111
            dmap_t in_hist, out_hist;
            set<size_t> rs;

112
            get_bv_deg(v, b, eweight, g, rs, in_hist, out_hist);
113
114
115
116
117
118
119
120
121
122

            cdeg_t cdeg;
            for (auto r : rs)
            {
                deg_t deg = std::make_tuple(in_hist[r], out_hist[r]);
                cdeg.push_back(deg);
            }

            bv_t bv(rs.begin(), rs.end());

123
124
125
126
            assert(bv.size() > 0);

            _bvs[nv] = bv;
            _degs[nv] = cdeg;
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154

            auto & cdh = _deg_hist[bv];
            cdh[cdeg]++;

            size_t d = bv.size();
            _D = max(_D, d);
            _dhist[d]++;
            _bhist[bv]++;

            auto& bmh = _embhist[bv];
            auto& bph = _epbhist[bv];
            bmh.resize(bv.size());
            bph.resize(bv.size());

            for (size_t i = 0; i < bv.size(); ++i)
            {
                size_t r = bv[i];
                _emhist[r] += get<0>(cdeg[i]);
                _ephist[r] += get<1>(cdeg[i]);
                bmh[i] += get<0>(cdeg[i]);
                bph[i] += get<1>(cdeg[i]);
            }
        }

        for (auto& bv_c : _bhist)
        {
            assert(bv_c.second > 0);
            for (auto r : bv_c.first)
155
                _r_count[r] += 1;
156
157
        }

158
159
160
161
162
163
164
165
166
167
168
169
170
        _actual_B = _r_count.size();
    }

    size_t get_v(size_t v)
    {
        constexpr size_t null =
            std::numeric_limits<size_t>::max();
        if (v >= _vmap.size())
            _vmap.resize(v + 1, null);
        size_t nv = _vmap[v];
        if (nv == null)
            nv = _vmap[v] = _bvs.size();
        if (nv >= _bvs.size())
171
        {
172
173
            _bvs.resize(nv + 1);
            _degs.resize(nv + 1);
174
        }
175
        return nv;
176
177
    }

178
    size_t get_r(size_t r)
179
    {
180
181
182
183
184
185
186
187
        constexpr size_t null =
            std::numeric_limits<size_t>::max();
        if (r >= _bmap.size())
            _bmap.resize(r + 1, null);
        size_t nr = _bmap[r];
        if (nr == null)
            nr = _bmap[r] = _r_count.size();
        if (nr >= _r_count.size())
188
        {
189
190
191
192
            _r_count.resize(nr + 1);
            _dhist.resize(nr + 2);
            _emhist.resize(nr + 1);
            _ephist.resize(nr + 1);
193
        }
194
195
196
197
198
199
200
201
202
203
        return nr;
    }


    template <class Graph, class Vprop, class Eprop>
    void get_bv_deg(size_t v, Vprop& b, Eprop&, Graph& g, set<size_t>& rs,
                    dmap_t& in_hist, dmap_t& out_hist)
    {
        auto& half_edges = _overlap_stats.get_half_edges(v);
        for (size_t u : half_edges)
204
        {
205
206
            size_t kin = in_degreeS()(u, g);
            size_t kout = out_degreeS()(u, g);
207

208
209
210
            auto r = get_r(b[u]);
            in_hist[r] += kin;
            out_hist[r] += kout;
211
212
213
214
215
216
        }

        for (auto& rk : in_hist)
            rs.insert(rk.first);
    }

217
    double get_partition_dl() const
218
219
220
221
222
223
224
    {
        double S = 0;
        for (size_t d = 1; d < _dhist.size(); ++d)
        {
            size_t nd = _dhist[d];
            if (nd == 0)
                continue;
225
226
227
228
229
            double x;
            if (_allow_empty)
                x = lbinom_fast(_total_B, d);
            else
                x = lbinom_fast(_actual_B, d);
230
            double ss = lbinom_careful((exp(x) + nd) - 1, nd); // not fast
231
232
233
234
235
236
237
            if (std::isinf(ss) || std::isnan(ss))
                ss = nd * x - lgamma_fast(nd + 1);
            assert(!std::isinf(ss));
            assert(!std::isnan(ss));
            S += ss;
        }

238
        S += lbinom_fast(_D + _N - 1, _N) + lgamma_fast(_N + 1);
239
240

        for (auto& bh : _bhist)
241
            S -= lgamma_fast(bh.second + 1);
242
243
244
245

        return S;
    }

246
    double get_deg_dl_ent() const
247
248
    {
        double S = 0;
249
        for (auto& ch : _deg_hist)
250
        {
251
252
            auto& bv = ch.first;
            auto& cdeg_hist = ch.second;
253

254
            size_t n_bv = _bhist.find(bv)->second;
255

256
257
258
            S += xlogx(n_bv);
            for (auto& dh : cdeg_hist)
                S -= xlogx(dh.second);
259
        }
260
261
262
263
264
265
266
267
        return S;
    }

    double get_deg_dl_uniform() const
    {
        double S = 0;

        for (auto& ch : _deg_hist)
268
        {
269
270
            auto& bv = ch.first;
            size_t n_bv = _bhist.find(bv)->second;
271

272
273
            if (n_bv == 0)
                continue;
274

275
276
            const auto& bmh = _embhist.find(bv)->second;
            const auto& bph = _epbhist.find(bv)->second;
277

278
279
280
281
282
283
            for (size_t i = 0; i < bv.size(); ++i)
            {
                S += lbinom(n_bv + bmh[i] - 1, bmh[i]);
                S += lbinom(n_bv + bph[i] - 1, bph[i]);
            }
        }
284

285
286
287
288
289
290
291
        for (size_t r = 0; r < _r_count.size(); ++r)
        {
            if (_r_count[r] == 0)
                continue;
            S += lbinom(_r_count[r] + _emhist[r] - 1,  _emhist[r]);
            S += lbinom(_r_count[r] + _ephist[r] - 1,  _ephist[r]);
        }
292

293
294
        return S;
    }
295

296
297
298
299
300
301
302
    double get_deg_dl_dist() const
    {
        double S = 0;
        for (auto& ch : _deg_hist)
        {
            auto& bv = ch.first;
            auto& cdeg_hist = ch.second;
303

304
            size_t n_bv = _bhist.find(bv)->second;
305

306
307
308
309
310
            if (n_bv == 0)
                continue;

            const auto& bmh = _embhist.find(bv)->second;
            const auto& bph = _epbhist.find(bv)->second;
311

312
            for (size_t i = 0; i < bv.size(); ++i)
313
            {
314
315
316
317
318
319
320
321
322
                if (_directed)
                {
                    S += log_q(bmh[i], n_bv);
                    S += log_q(bph[i], n_bv);
                }
                else
                {
                    S += log_q(bph[i] - n_bv, n_bv);
                }
323
            }
324
325
326
327
328
329
330
331
332
333
334
335
336

            S += lgamma_fast(n_bv + 1);

            for (auto& dh : cdeg_hist)
                S -= lgamma_fast(dh.second + 1);
        }

        for (size_t r = 0; r < _r_count.size(); ++r)
        {
            if (_r_count[r] == 0)
                continue;
            S += lbinom(_r_count[r] + _emhist[r] - 1,  _emhist[r]);
            S += lbinom(_r_count[r] + _ephist[r] - 1,  _ephist[r]);
337
338
339
340
        }
        return S;
    }

341
342
343
344
345
346
347
348
349
350
351
352
353
354
    // Dispatch to the degree description-length variant selected by `kind`
    // (one of the deg_dl_kind values); any other value yields a quiet NaN.
    double get_deg_dl(int kind) const
    {
        switch (kind)
        {
        case deg_dl_kind::ENT:
            return get_deg_dl_ent();
        case deg_dl_kind::UNIFORM:
            return get_deg_dl_uniform();
        case deg_dl_kind::DIST:
            return get_deg_dl_dist();
        default:
            return numeric_limits<double>::quiet_NaN();
        }
    }

    template <class Graph>
357
358
    bool get_n_bv(size_t v, size_t r, size_t nr, const bv_t& bv,
                  const cdeg_t& deg, bv_t& n_bv, cdeg_t& n_deg, Graph& g,
359
                  size_t in_deg = 0, size_t out_deg = 0) const
360
    {
361
362
        size_t kin = (in_deg + out_deg == 0) ? in_degreeS()(v, g) : in_deg;
        size_t kout = (in_deg + out_deg == 0) ? out_degreeS()(v, g) : out_deg;
363

364
365
366
        gt_hash_map<size_t, std::pair<int, int>> deg_delta;

        auto& d_r = deg_delta[r];
367
368
        d_r.first -= kin;
        d_r.second -= kout;
369

370
        auto& d_nr = deg_delta[nr];
371
372
        d_nr.first += kin;
        d_nr.second += kout;
373
374
375
376

        n_deg.clear();
        n_bv.clear();
        bool is_same_bv = true;
377
        bool has_r = false, has_nr = false;
378
379
380
381
382
        for (size_t i = 0; i < bv.size(); ++i)
        {
            size_t s = bv[i];
            auto k_s = deg[i];

383
            auto& d_s = deg_delta[s];
384
385
386
387
388
389
390
391
392
393
            get<0>(k_s) += d_s.first;
            get<1>(k_s) += d_s.second;

            d_s.first = d_s.second = 0;

            if (s == r)
                has_r = true;

            if (s == nr)
                has_nr = true;
394
395
396
397
398
399
400
401
402
403
404
405

            if ((get<0>(k_s) + get<1>(k_s)) > 0)
            {
                n_bv.push_back(s);
                n_deg.push_back(k_s);
            }
            else
            {
                is_same_bv = false;
            }
        }

406
        if (!has_r || !has_nr)
407
408
        {
            is_same_bv = false;
409
            std::array<size_t, 2> ss = {{r, nr}};
410
            for (auto s : ss)
411
            {
412
                auto& d_s = deg_delta[s];
413
414
415
416
417
                if (d_s.first + d_s.second == 0)
                    continue;
                size_t kin = d_s.first;
                size_t kout = d_s.second;
                auto pos = std::lower_bound(n_bv.begin(), n_bv.end(), s);
418
419
                auto dpos = n_deg.begin();
                std::advance(dpos, pos - n_bv.begin());
420
                n_bv.insert(pos, s);
421
422
423
424
425
426
427
                n_deg.insert(dpos, make_pair(kin, kout));
            }
        }
        return is_same_bv;
    }

    // get deg counts without increasing the container
428
    size_t get_deg_count(const bv_t& bv, const cdeg_t& deg) const
429
430
431
432
433
434
435
436
437
438
439
440
441
442
    {
        auto iter = _deg_hist.find(bv);
        if (iter == _deg_hist.end())
            return 0;
        auto& hist = iter->second;
        if (hist.empty())
            return 0;
        auto diter = hist.find(deg);
        if (diter == hist.end())
            return 0;
        return diter->second;
    }

    // get bv counts without increasing the container
443
    size_t get_bv_count(const bv_t& bv) const
444
445
446
447
448
449
450
451
    {
        auto iter = _bhist.find(bv);
        if (iter == _bhist.end())
            return 0;
        return iter->second;
    }

    template <class Graph>
452
453
    double get_delta_partition_dl(size_t v, size_t r, size_t nr, const Graph& g,
                                  size_t in_deg = 0, size_t out_deg = 0)
454
    {
455
        if (r == nr)
456
457
            return 0;

458
459
460
461
462
463
464
465
466
        size_t o_r = r;
        size_t o_nr = nr;
        r = get_r(r);
        nr = get_r(nr);

        size_t u = _overlap_stats.get_node(v);

        u = get_v(u);

467
        auto& bv = _bvs[u];
468
469
        assert(bv.size() > 0);
        bv_t n_bv;
470
        size_t d = bv.size();
471
472
        const cdeg_t& deg = _degs[u];
        cdeg_t n_deg;
473

474
475
        bool is_same_bv = get_n_bv(v, r, nr, bv, deg, n_bv, n_deg, g, in_deg,
                                   out_deg);
476

477
        assert(n_bv.size() > 0);
478
479
480
        if (is_same_bv)
            return 0;

481
482
483
        size_t n_d = n_bv.size();
        size_t n_D = _D;

484
        if (d == _D && n_d < d && _dhist[d] == 1)
485
486
487
488
489
490
491
492
493
494
495
496
        {
            n_D = 1;
            for (auto& bc : _bhist)
            {
                if (bc.first.size() == d || bc.second == 0)
                    continue;
                n_D = max(n_D, bc.first.size());
            }
        }

        n_D = max(n_D, n_d);

497
498
        double S_a = 0, S_b = 0;

499
500
        if (n_D != _D)
        {
501
502
            S_b += lbinom_fast(_D  + _N - 1, _N);
            S_a += lbinom_fast(n_D + _N - 1, _N);
503
504
        }

505
        int dB = 0;
506
        if (_overlap_stats.virtual_remove_size(v, o_r, in_deg, out_deg) == 0)
507
            dB--;
508
        if (_overlap_stats.get_block_size(o_nr) == 0)
509
            dB++;
510

511
        auto get_S_d = [&] (size_t d_i, int delta, int dB) -> double
512
513
514
515
            {
                int nd = int(_dhist[d_i]) + delta;
                if (nd == 0)
                    return 0.;
516
517
518
519
520
                double x;
                if (_allow_empty)
                    x = lbinom_fast(_total_B + dB, d_i);
                else
                    x = lbinom_fast(_actual_B + dB, d_i);
521
                double S = lbinom_careful(exp(x) + nd - 1, nd); // not fast
522
523
524
525
526
                if (std::isinf(S) || std::isnan(S))
                    S = nd * x - lgamma_fast(nd + 1);
                return S;
            };

527
        if (dB == 0 || _allow_empty)
528
529
530
531
532
533
534
535
        {
            if (n_d != d)
            {
                S_b += get_S_d(d,  0, 0) + get_S_d(n_d, 0, 0);
                S_a += get_S_d(d, -1, 0) + get_S_d(n_d, 1, 0);
            }
        }
        else
536
        {
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
            for (size_t di = 0; di < min(_actual_B + abs(dB) + 1, _dhist.size()); ++di)
            {
                if (d != n_d)
                {
                    if (di == d)
                    {
                        S_b += get_S_d(d,  0, 0);
                        S_a += get_S_d(d, -1, dB);
                        continue;
                    }
                    if (di == n_d)
                    {
                        S_b += get_S_d(n_d, 0, 0);
                        S_a += get_S_d(n_d, 1, dB);
                        continue;
                    }
                }
                if (_dhist[di] == 0)
                    continue;
                S_b += get_S_d(di, 0, 0);
                S_a += get_S_d(di, 0, dB);
            }
559
560
        }

561
        size_t bv_count = get_bv_count(bv);
562
        assert(bv_count > 0);
563
        size_t n_bv_count = get_bv_count(n_bv);
564
565
566

        auto get_S_b = [&] (bool is_bv, int delta) -> double
            {
567
                assert(int(bv_count) + delta >= 0);
568
569
570
571
572
                if (is_bv)
                    return -lgamma_fast(bv_count + delta + 1);
                return -lgamma_fast(n_bv_count + delta + 1);
            };

573
574
        S_b += get_S_b(true,  0) + get_S_b(false, 0);
        S_a += get_S_b(true, -1) + get_S_b(false, 1);
575

576
577
        return S_a - S_b;
    }
578

579
    template <class Graph>
580
581
    double get_delta_edges_dl(size_t v, size_t r, size_t nr, size_t actual_B,
                              const Graph&)
582
    {
583
        if (r == nr || _allow_empty)
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
            return 0;

        double S_b = 0, S_a = 0;

        int dB = 0;
        if (_overlap_stats.virtual_remove_size(v, r) == 0)
            dB--;
        if (_overlap_stats.get_block_size(nr) == 0)
            dB++;

        if (dB != 0)
        {
            auto get_x = [](size_t B) -> size_t
                {
                    if (is_directed::apply<Graph>::type::value)
                        return B * B;
                    else
                        return (B * (B + 1)) / 2;
                };

604
605
            S_b += lbinom(get_x(actual_B) + _E - 1, _E);
            S_a += lbinom(get_x(actual_B + dB) + _E - 1, _E);
606
607
608
609
610
        }

        return S_a - S_b;
    }

611

612
    template <class Graph, class EWeight>
613
614
615
    double get_delta_deg_dl(size_t v, size_t r, size_t nr, const EWeight&,
                            const Graph& g, size_t in_deg = 0,
                            size_t out_deg = 0)
616
617
618
    {
        if (r == nr)
            return 0;
619

620
621
622
        r = get_r(r);
        nr = get_r(nr);

623
        double S_b = 0, S_a = 0;
624

625
        size_t u = get_v(_overlap_stats.get_node(v));
626
        auto& bv = _bvs[u];
627
        bv_t n_bv;
628

629
630
        const cdeg_t& deg = _degs[u];
        cdeg_t n_deg;
631

632
633
        bool is_same_bv = get_n_bv(v, r, nr, bv, deg, n_bv, n_deg, g, in_deg,
                                   out_deg);
634

635
636
        size_t bv_count = get_bv_count(bv);
        size_t n_bv_count = bv_count;
637

638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
        auto get_S_bv = [&] (bool is_bv, int delta) -> double
            {
                if (is_bv)
                    return lgamma_fast(bv_count + delta + 1);
                return lgamma_fast(n_bv_count + delta + 1);
            };

        auto get_S_e = [&] (bool is_bv, int bdelta, int deg_delta) -> double
            {
                size_t bv_c = ((is_bv) ? bv_count : n_bv_count) + bdelta;
                if (bv_c == 0)
                    return 0.;

                const cdeg_t& deg_i = (is_bv) ? deg : n_deg;
                const auto& bv_i = (is_bv) ? bv : n_bv;
653

654
655
                double S = 0;
                if (((is_bv) ? bv_count : n_bv_count) > 0)
656
                {
657
658
                    const auto& bmh = _embhist.find(bv_i)->second;
                    const auto& bph = _epbhist.find(bv_i)->second;
659

660
661
662
663
                    assert(bmh.size() == bv_i.size());
                    assert(bph.size() == bv_i.size());

                    for (size_t i = 0; i < bv_i.size(); ++i)
664
                    {
665
666
667
668
669
670
671
672
673
                        if (_directed)
                        {
                            S += log_q(size_t(bmh[i] + deg_delta * int(get<0>(deg_i[i]))), bv_c);
                            S += log_q(size_t(bph[i] + deg_delta * int(get<1>(deg_i[i]))), bv_c);
                        }
                        else
                        {
                            S += log_q(size_t(bph[i] + deg_delta * int(get<1>(deg_i[i]))) - bv_c, bv_c);
                        }
674
                    }
675
676
677
678
679
                }
                else
                {
                    for (size_t i = 0; i < bv_i.size(); ++i)
                    {
680
681
682
683
684
685
686
687
688
                        if (_directed)
                        {
                            S += log_q(size_t(deg_delta * int(get<0>(deg_i[i]))), bv_c);
                            S += log_q(size_t(deg_delta * int(get<1>(deg_i[i]))), bv_c);
                        }
                        else
                        {
                            S += log_q(size_t(deg_delta * int(get<1>(deg_i[i]))) - bv_c, bv_c);
                        }
689
690
                    }
                }
691

692
693
                return S;
            };
694

695
        auto get_S_e2 = [&] (int deg_delta, int ndeg_delta) -> double
696
            {
697
698
699
                double S = 0;
                const auto& bmh = _embhist.find(bv)->second;
                const auto& bph = _epbhist.find(bv)->second;
700

701
                for (size_t i = 0; i < bv.size(); ++i)
702
                {
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
                    if (_directed)
                    {
                        S += log_q(size_t(bmh[i] +
                                          deg_delta * int(get<0>(deg[i])) +
                                          ndeg_delta * int(get<0>(n_deg[i]))),
                                   bv_count);
                        S += log_q(size_t(bph[i] +
                                          deg_delta * int(get<1>(deg[i])) +
                                          ndeg_delta * int(get<1>(n_deg[i]))),
                                   bv_count);
                    }
                    else
                    {
                        S += log_q(size_t(bph[i] +
                                          deg_delta * int(get<1>(deg[i])) +
                                          ndeg_delta * int(get<1>(n_deg[i])))
                                   - bv_count,
                                   bv_count);
                    }
722
723
724
                }
                return S;
            };
725

726
727
        if (!is_same_bv)
        {
728
729
            n_bv_count = get_bv_count(n_bv);

730
731
            S_b += get_S_bv(true,  0) + get_S_bv(false, 0);
            S_a += get_S_bv(true, -1) + get_S_bv(false, 1);
732

733
734
735
736
737
738
739
740
741
742
743
            S_b += get_S_e(true,  0,  0) + get_S_e(false, 0, 0);
            S_a += get_S_e(true, -1, -1) + get_S_e(false, 1, 1);
        }
        else
        {
            S_b += get_S_e2( 0, 0);
            S_a += get_S_e2(-1, 1);
        }

        size_t deg_count = get_deg_count(bv, deg);
        size_t n_deg_count = get_deg_count(n_bv, n_deg);
744

745
        auto get_S_deg = [&] (bool is_deg, int delta) -> double
746
            {
747
748
749
750
                if (is_deg)
                    return -lgamma_fast(deg_count + delta + 1);
                return -lgamma_fast(n_deg_count + delta + 1);
            };
751

752
753
754
755
        S_b += get_S_deg(true,  0) + get_S_deg(false, 0);
        S_a += get_S_deg(true, -1) + get_S_deg(false, 1);

        auto is_in = [&] (const bv_t& bv, size_t r) -> bool
756
            {
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
                auto iter = lower_bound(bv.begin(), bv.end(), r);
                if (iter == bv.end())
                    return false;
                if (size_t(*iter) != r)
                    return false;
                return true;
            };

        for (size_t s : bv)
        {
            S_b += lbinom_fast(_r_count[s] + _emhist[s] - 1, _emhist[s]);
            S_b += lbinom_fast(_r_count[s] + _ephist[s] - 1, _ephist[s]);
        }

        for (size_t s : n_bv)
        {
            if (is_in(bv, s))
                continue;
            S_b += lbinom_fast(_r_count[s] + _emhist[s] - 1, _emhist[s]);
            S_b += lbinom_fast(_r_count[s] + _ephist[s] - 1, _ephist[s]);
        }
778

779
780
        gt_hash_map<size_t, pair<int, int>> deg_delta;
        gt_hash_map<size_t, int> r_count_delta;
781

782
783
784
        if (bv != n_bv)
        {
            if (n_bv_count == 0)
785
            {
786
787
                for (auto s : n_bv)
                    r_count_delta[s] += 1;
788
789
            }

790
            if (bv_count == 1)
791
            {
792
793
794
795
                for (auto s : bv)
                   r_count_delta[s] -= 1;
            }
        }
796

797
798
799
800
        if (r != nr)
        {
            size_t kin = (in_deg + out_deg == 0) ? in_degreeS()(v, g) : in_deg;
            size_t kout = (in_deg + out_deg == 0) ? out_degreeS()(v, g) : out_deg;
801

802
803
804
            auto& d_r = deg_delta[r];
            d_r.first -= kin;
            d_r.second -= kout;
805

806
807
808
809
810
811
812
813
814
815
816
817
            auto& d_nr = deg_delta[nr];
            d_nr.first += kin;
            d_nr.second += kout;
        }

        for (size_t s : bv)
        {
            S_a += lbinom_fast(_r_count[s] + r_count_delta[s] + _emhist[s] + deg_delta[s].first - 1,
                               _emhist[s] + deg_delta[s].first);
            S_a += lbinom_fast(_r_count[s] + r_count_delta[s] + _ephist[s] + deg_delta[s].second - 1,
                               _ephist[s] + deg_delta[s].second);
        }
818

819
820
821
        for (size_t s : n_bv)
        {
            if (!is_in(bv, s))
822
            {
823
824
825
826
                S_a += lbinom_fast(_r_count[s] + r_count_delta[s] + _emhist[s] + deg_delta[s].first - 1,
                                   _emhist[s] + deg_delta[s].first);
                S_a += lbinom_fast(_r_count[s] + r_count_delta[s] + _ephist[s] + deg_delta[s].second - 1,
                                   _ephist[s] + deg_delta[s].second);
827
828
829
            }
        }

830
        return S_a - S_b;
831
832
    }

833
834
    template <class Graph>
    void move_vertex(size_t v, size_t r, size_t nr, bool, Graph& g,
835
                     size_t in_deg = 0, size_t out_deg = 0)
836
    {
837
        if (r == nr)
838
839
            return;

840
841
842
843
844
        r = get_r(r);
        nr = get_r(nr);

        auto u =_overlap_stats.get_node(v);
        u = get_v(u);
845
        auto& bv = _bvs[u];
846
847
        assert(!bv.empty());
        bv_t n_bv;
848
        cdeg_t& deg = _degs[u];
849
        cdeg_t n_deg;
850
851
        size_t d = bv.size();

852
853
        bool is_same_bv = get_n_bv(v, r, nr, bv, deg, n_bv, n_deg, g, in_deg,
                                   out_deg);
854
        assert(!n_bv.empty());
855
856
857
858
859
860
861
862
863
864
        size_t n_d = n_bv.size();

        if (!is_same_bv)
        {
            _dhist[d] -= 1;
            auto& bv_count = _bhist[bv];
            bv_count -= 1;

            if (bv_count == 0)
            {
865
                _bhist.erase(bv);
866
                for (auto s : bv)
867
                {
868
                    _r_count[s]--;
869
870
871
                    if (_r_count[s] == 0)
                        _actual_B--;
                }
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
            }

            if (d == _D && _dhist[d] == 0)
            {
                _D = 1;
                for (auto& bc : _bhist)
                {
                    if (bc.second == 0)
                        continue;
                    _D = max(_D, bc.first.size());
                }
            }

            _dhist[n_d] += 1;
            auto& n_bv_count = _bhist[n_bv];
            n_bv_count += 1;

            if (n_bv_count == 1)
            {
                for (auto s : n_bv)
892
893
894
                {
                    if (_r_count[s] == 0)
                        _actual_B++;
895
                    _r_count[s]++;
896
                }
897
898
899
900
901
902
            }

            if (n_d > _D)
                _D = n_d;
        }

903
904
905
906
907
        auto& deg_h = _deg_hist[bv];
        auto& deg_count = deg_h[deg];
        deg_count -= 1;
        if (deg_count == 0)
            deg_h.erase(deg);
908
909
910
911
912
913
914
915
916
917
        auto& bmh = _embhist[bv];
        auto& bph = _epbhist[bv];
        assert(bmh.size() == bv.size());
        assert(bph.size() == bv.size());
        for (size_t i = 0; i < bv.size(); ++i)
        {
            bmh[i] -= get<0>(deg[i]);
            bph[i] -= get<1>(deg[i]);
        }

918
        if (deg_h.empty())
919
        {
920
921
922
            _deg_hist.erase(bv);
            _embhist.erase(bv);
            _epbhist.erase(bv);
923
924
        }

925
926
927
928
929
        size_t kin = (in_deg + out_deg == 0) ? in_degreeS()(v, g) : in_deg;
        size_t kout = (in_deg + out_deg == 0) ? out_degreeS()(v, g) : out_deg;
        _emhist[r] -= kin;
        _ephist[r] -= kout;

930
931
932
933
934
935
936
937
938
939
940
941
        auto& hist = _deg_hist[n_bv];
        hist[n_deg] += 1;
        auto& n_bmh = _embhist[n_bv];
        auto& n_bph = _epbhist[n_bv];
        n_bmh.resize(n_bv.size());
        n_bph.resize(n_bv.size());
        for (size_t i = 0; i < n_bv.size(); ++i)
        {
            n_bmh[i] += get<0>(n_deg[i]);
            n_bph[i] += get<1>(n_deg[i]);
        }

942
943
        _emhist[nr] += kin;
        _ephist[nr] += kout;
944
945
946

        _bvs[u].swap(n_bv);
        _degs[u].swap(n_deg);
947
        assert(_bvs[u].size() > 0);
948

949
    }
950

951
952
953
954
955
    size_t get_actual_B()
    {
        return _actual_B;
    }

956
957
958
959
960
    // Increase the total block count (_total_B) by one.
    void add_block()
    {
        _total_B++;
    }

private:
    overlap_stats_t& _overlap_stats; // overlap statistics (held by reference)
    vector<size_t>& _bmap;     // external -> internal block index (see get_r())
    vector<size_t>& _vmap;     // external -> internal node index (see get_v())
    size_t _N;                 // number of nodes (size of vlist at construction)
    size_t _E;                 // edge count passed at construction
    size_t _actual_B;          // block count maintained by move_vertex()
    size_t _total_B;           // block count passed at construction (see add_block())
    size_t _D;                 // largest block-vector size
    bool _allow_empty;         // use _total_B instead of _actual_B in the partition dl
    bool _directed;            // whether the graph type is directed
    vector<int> _dhist;        // d-histogram
    vector<int> _r_count;      // m_r
    bhist_t _bhist;            // b-histogram
    vector<size_t> _emhist;    // e-_r histogram
    vector<size_t> _ephist;    // e+_r histogram
    ebhist_t _embhist;         // e-^r_b histogram
    ebhist_t _epbhist;         // e+^r_b histogram
    deg_hist_t _deg_hist;      // n_k^b histogram
    vector<bv_t> _bvs;         // bv node map
    vector<cdeg_t> _degs;      // deg node map
};

} // namespace graph_tool

#endif // GRAPH_BLOCKMODEL_OVERLAP_PARTITION_HH