Skip to content

Commit

Permalink
Fixes #13
Browse files Browse the repository at this point in the history
Adding ability to remove chains
  • Loading branch information
Samreay committed Dec 19, 2016
1 parent 812190b commit 8e85db5
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 2 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ post, it can be solved by explicitly install the `matplotlib` dependency `dvipng

### Update History

##### 0.15.4
* Adding ability to remove chains.

##### 0.15.3
* Adding ability to plot the walks of multiple chains together.
Expand Down
61 changes: 59 additions & 2 deletions chainconsumer/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ChainConsumer(object):
""" A class for consuming chains produced by an MCMC walk
"""
__version__ = "0.15.3"
__version__ = "0.15.4"

def __init__(self):
logging.basicConfig()
Expand Down Expand Up @@ -163,6 +163,59 @@ def add_chain(self, chain, parameters=None, name=None, weights=None, posterior=N
self._init_params()
return self

def remove_chain(self, chain=-1):
"""
Removes a chain from ChainConsumer. Calling this will require any configurations set to be redone!
Parameters
----------
chain : int|str, list[str]
The chain(s) to remove. You can pass in either the chain index, or the chain name, to remove it.
By default removes the last chain added.
Returns
-------
ChainConsumer
Itself, to allow chaining calls.
"""
if isinstance(chain, str) or isinstance(chain, int):
chain = [chain]
elif isinstance(chain, list):
for c in chain:
assert isinstance(c, str), "If you specify a list, " \
"you must specify chain names, not indexes." \
"This is to avoid confusion when specifying," \
"for example, [0,0]. As this might be an error," \
"or a request to remove the first two chains."
for c in chain:
index = self._get_chain(c)
parameters = self._parameters[index]

del self._chains[index]
del self._names[index]
del self._weights[index]
del self._posteriors[index]
del self._parameters[index]
del self._grids[index]
del self._num_free[index]
del self._num_data[index]

# Recompute all_parameters
for p in parameters:
has = False
for ps in self._parameters:
if p in ps:
has = True
break
if not has:
i = self._all_parameters.index(p)
del self._all_parameters[i]

# Need to reconfigure
self._init_params()

return self

def configure(self, statistics="max", max_ticks=5, plot_hists=True, flip=True,
serif=True, sigmas=None, summary=None, bins=None, rainbow=None,
colors=None, linestyles=None, linewidths=None, kde=False, smooth=None,
Expand Down Expand Up @@ -991,7 +1044,11 @@ def plot_walks(self, parameters=None, truth=None, extents=None, display=False,
chains = [chains]
chains = [self._get_chain(c) for c in chains]

all_parameters = list(set([p for i in chains for p in self._parameters[i]]))
all_parameters2 = [p for i in chains for p in self._parameters[i]]
all_parameters = []
for p in all_parameters2:
if p not in all_parameters:
all_parameters.append(p)

if parameters is None:
parameters = all_parameters
Expand Down
1 change: 1 addition & 0 deletions doc/chain_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ General Methods
---------------
* :func:`chainconsumer.ChainConsumer.add_chain` - Add a chain!
* :func:`chainconsumer.ChainConsumer.divide_chain` - Split a chain into multiple chains to inspect each walk.
* :func:`chainconsumer.ChainConsumer.remove_chain` - Remove a chain.

Plotting Methods
----------------
Expand Down
79 changes: 79 additions & 0 deletions test_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,3 +847,82 @@ def test_dic_posterior_dependence(self):
dic2 = 2 * np.mean(-2 * p2) + 2 * norm.logpdf(0, scale=2)
assert np.isclose(bics[0], dic1 - dic2, atol=1e-3)

def test_remove_last_chain(self):
tolerance = 5e-2
consumer = ChainConsumer()
consumer.add_chain(self.data)
consumer.add_chain(self.data * 2)
consumer.remove_chain()
consumer.configure(bins=1.6)
summary = consumer.get_summary()
assert isinstance(summary, dict)
actual = np.array(list(summary.values())[0])
expected = np.array([3.5, 5.0, 6.5])
diff = np.abs(expected - actual)
assert np.all(diff < tolerance)

def test_remove_first_chain(self):
tolerance = 5e-2
consumer = ChainConsumer()
consumer.add_chain(self.data * 2)
consumer.add_chain(self.data)
consumer.remove_chain(chain=0)
consumer.configure(bins=1.6)
summary = consumer.get_summary()
assert isinstance(summary, dict)
actual = np.array(list(summary.values())[0])
expected = np.array([3.5, 5.0, 6.5])
diff = np.abs(expected - actual)
assert np.all(diff < tolerance)

def test_remove_chain_by_name(self):
tolerance = 5e-2
consumer = ChainConsumer()
consumer.add_chain(self.data * 2, name="a")
consumer.add_chain(self.data, name="b")
consumer.remove_chain(chain="a")
consumer.configure(bins=1.6)
summary = consumer.get_summary()
assert isinstance(summary, dict)
actual = np.array(list(summary.values())[0])
expected = np.array([3.5, 5.0, 6.5])
diff = np.abs(expected - actual)
assert np.all(diff < tolerance)

def test_remove_chain_recompute_params(self):
tolerance = 5e-2
consumer = ChainConsumer()
consumer.add_chain(self.data * 2, parameters=["p1"], name="a")
consumer.add_chain(self.data, parameters=["p2"], name="b")
consumer.remove_chain(chain="a")
consumer.configure(bins=1.6)
summary = consumer.get_summary()
assert isinstance(summary, dict)
assert "p2" in summary
assert "p1" not in summary
actual = np.array(list(summary.values())[0])
expected = np.array([3.5, 5.0, 6.5])
diff = np.abs(expected - actual)
assert np.all(diff < tolerance)

def test_remove_multiple_chains(self):
tolerance = 5e-2
consumer = ChainConsumer()
consumer.add_chain(self.data * 2, parameters=["p1"], name="a")
consumer.add_chain(self.data, parameters=["p2"], name="b")
consumer.add_chain(self.data * 3, parameters=["p3"], name="c")
consumer.remove_chain(chain=["a", "c"])
consumer.configure(bins=1.6)
summary = consumer.get_summary()
assert isinstance(summary, dict)
assert "p2" in summary
assert "p1" not in summary
assert "p3" not in summary
actual = np.array(list(summary.values())[0])
expected = np.array([3.5, 5.0, 6.5])
diff = np.abs(expected - actual)
assert np.all(diff < tolerance)

def test_remove_multiple_chains_fails(self):
with pytest.raises(AssertionError):
ChainConsumer().add_chain(self.data).remove_chain(chain=[0, 0])

0 comments on commit 8e85db5

Please sign in to comment.