Commit 41dc36a4 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference.mcmc.TemperingState: Keep track of state indices

parent d37c0de0
......@@ -601,11 +601,13 @@ class TemperingState(object):
Inverse temperature values.
def __init__(self, states, betas):
def __init__(self, states, betas, idx=None):
if not (len(states) == len(betas)):
raise ValueError("states and betas must be of the same size")
self.states = states
self.betas = betas
if idx is None:
self.idx = list(range(len(betas)))
def entropy(self, **kwargs):
"""Returns the sum of the entropy of the parallel states. All keyword
......@@ -660,8 +662,8 @@ class TemperingState(object):
ddS = -(P1_f + P2_f - P1_b - P2_b)
if ddS < 0 or numpy.random.random() < exp(-ddS):
self.states[j], self.states[i] = \
self.states[i], self.states[j]
self.states[j], self.states[i], self.idx[j], self.idx[i] = \
self.states[i], self.states[j], self.idx[i], self.idx[j]
nswaps += 1
dS += ddS
if check_verbose(verbose):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment