Commit 41dc36a4 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference.mcmc.TemperingState: Keep track of state indices

parent d37c0de0
...@@ -601,11 +601,13 @@ class TemperingState(object): ...@@ -601,11 +601,13 @@ class TemperingState(object):
Inverse temperature values. Inverse temperature values.
""" """
def __init__(self, states, betas): def __init__(self, states, betas, idx=None):
if not (len(states) == len(betas)): if not (len(states) == len(betas)):
raise ValueError("states and betas must be of the same size") raise ValueError("states and betas must be of the same size")
self.states = states self.states = states
self.betas = betas self.betas = betas
if idx is None:
self.idx = list(range(len(betas)))
def entropy(self, **kwargs): def entropy(self, **kwargs):
"""Returns the sum of the entropy of the parallel states. All keyword """Returns the sum of the entropy of the parallel states. All keyword
...@@ -660,8 +662,8 @@ class TemperingState(object): ...@@ -660,8 +662,8 @@ class TemperingState(object):
ddS = -(P1_f + P2_f - P1_b - P2_b) ddS = -(P1_f + P2_f - P1_b - P2_b)
if ddS < 0 or numpy.random.random() < exp(-ddS): if ddS < 0 or numpy.random.random() < exp(-ddS):
self.states[j], self.states[i] = \ self.states[j], self.states[i], self.idx[j], self.idx[i] = \
self.states[i], self.states[j] self.states[i], self.states[j], self.idx[i], self.idx[j]
nswaps += 1 nswaps += 1
dS += ddS dS += ddS
if check_verbose(verbose): if check_verbose(verbose):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment