syrte
7/11/2016 - 2:59 AM

hist2d_stats.py

import numpy as np
from matplotlib import pyplot as plt

def hist_stats(x, y, bins=10, func=np.mean, nmin=None, **kwds):
    """Plot per-bin statistics of `y` against bin centres of `x`.

    `x`, `y`, `bins`, `func` and `nmin` are forwarded unchanged to
    ``binstats`` (defined elsewhere). Presumably ``func`` is applied to the
    `y` values that fall in each `x` bin, and ``nmin`` is a minimum count
    per bin -- confirm against ``binstats``.

    If ``func`` returns several values per bin (e.g. percentiles), one line
    is drawn per value. Every extra keyword is passed to ``plt.plot``:
    a scalar (or ``None``) is shared by all lines, while any other value is
    treated as a per-line sequence and indexed by line number.

    Returns
    -------
    list
        The concatenated results of the ``plt.plot`` calls, one call per line.

    Raises
    ------
    ValueError
        If ``binstats`` does not return exactly one edge array (i.e. the
        binning is not 1-D), or the statistics are not 2-D after arranging
        them as (n_lines, n_bins).

    Example::

        x = np.random.rand(1000)
        hist_stats(x, x, func=lambda x: np.percentile(x, [15, 50, 85]),
                   ls=['--', '-', '--'], lw=[1, 2, 1])
    """
    stats, edges, _count = binstats(x, y, bins=bins, func=func, nmin=nmin)
    # Transpose so each row is one statistic across all bins; a scalar
    # statistic yields a 1-D array, which atleast_2d promotes to one row.
    stats = np.atleast_2d(stats.T)

    # Explicit checks rather than `assert`, which is stripped under -O.
    if len(edges) != 1:
        raise ValueError(
            "hist_stats expects 1-D binning, got %d edge arrays" % len(edges))
    if stats.ndim != 2:
        raise ValueError(
            "hist_stats expects func to return a scalar or 1-D array per "
            "bin, got statistics with ndim=%d" % stats.ndim)

    # Bin centres: midpoints between consecutive edges.
    bin_edges = edges[0]
    X = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    lines = []
    for i, row in enumerate(stats):
        # np.isscalar(None) is False, so test None explicitly; otherwise a
        # keyword like label=None would be indexed and raise TypeError.
        opts = {k: (v if v is None or np.isscalar(v) else v[i])
                for k, v in kwds.items()}
        lines += plt.plot(X, row, **opts)
    return lines
    

def hist2d_stats(x, y, z, bins=10, func=np.mean, nmin=None, **kwds):
    """Draw a ``pcolormesh`` of a statistic of `z` over 2-D bins of (`x`, `y`).

    The inputs are forwarded to ``binstats`` (defined elsewhere) as
    ``binstats([x, y], z, bins=bins, func=func, nmin=nmin)``. Presumably
    ``func`` is evaluated on the `z` values in each (x, y) bin -- confirm
    against ``binstats``. Non-finite results (NaN/inf) are masked, so those
    cells are not painted.

    Extra keywords go to ``plt.pcolormesh``. Unless the caller passes them
    (or a non-None ``norm``, which matplotlib does not allow together with
    ``vmin``/``vmax``), ``vmin``/``vmax`` default to the min/max of the
    finite statistics.

    Returns
    -------
    The object returned by ``plt.pcolormesh``.
    """
    stats, edges, _count = binstats([x, y], z, bins=bins, func=func, nmin=nmin)
    X, Y = edges
    # Assumes binstats indexes stats as (x_bin, y_bin); pcolormesh wants
    # rows along y and columns along x, hence the transpose.
    Z = np.ma.masked_invalid(stats.T)

    # Skip the colour-limit defaults when a norm is supplied (matplotlib
    # rejects norm together with vmin/vmax) or when no finite value exists
    # (Z.min() would then return the masked constant).
    if kwds.get('norm') is None and Z.count() > 0:
        kwds.setdefault('vmin', Z.min())
        kwds.setdefault('vmax', Z.max())
    return plt.pcolormesh(X, Y, Z, **kwds)