# syrte
# 7/10/2016 - 5:19 AM
#
# nd (N-dimensional) version
#
# nd (N-dimensional) version

from __future__ import division, print_function, absolute_import

import warnings
import numpy as np
from collections import namedtuple

# Result of `binstats`: per-bin statistics, the list of bin-edge arrays
# (one per dimension) and the number of points that fell in each bin.
BinStats = namedtuple('BinStats',
                      ('stats', 'bin_edges', 'bin_count'))


def _auto_edges(x, nbin):
    """Return ``nbin + 1`` evenly spaced edges covering the non-NaN range of `x`.

    Bins are half-open (``[lo, hi)``), so the upper edge is placed strictly
    above ``max(x)`` to keep the maximum sample in the last bin.  A constant
    `x` gets a unit-wide range centred on its value.
    """
    xmin, xmax = float(np.nanmin(x)), float(np.nanmax(x))
    if xmin == xmax:
        xmin = xmin - 0.5
        xmax = xmax + 0.5
    else:
        # Pad relative to the span (not to xmax itself, which may be zero or
        # negative), then step at least one ulp upward so the edge is
        # strictly greater than the data maximum.
        xmax = np.nextafter(xmax + (xmax - xmin) * 1e-10, np.inf)
    assert xmax > xmin
    return np.linspace(xmin, xmax, nbin + 1)


def _empty_value(func, ys):
    """Return the ndarray used to fill empty bins.

    This is ``func`` called with one empty list per entry of `ys`.  If that
    call raises, ``func`` is probed on the first element of each entry of
    `ys` and a float NaN array of the same shape is returned instead.
    """
    with warnings.catch_warnings():
        # Numpy generates a RuntimeWarning for mean/std/... of an empty list.
        warnings.filterwarnings('ignore', category=RuntimeWarning)
        try:
            return np.asarray(func(*[[] for _ in ys]))
        except Exception:
            # `func` cannot handle empty input (e.g. it calls array methods);
            # evaluate it on one sample only to learn the output shape.
            probe = np.asarray(func(*[y[:1] for y in ys]))
            return np.full_like(probe, np.nan, dtype='float')


def binstats(xs, ys, bins=10, func=np.mean, nmin=None):
    """Compute a user-defined statistic of `ys` in bins defined over `xs`.

    Parameters
    ----------
    xs : array_like or list of array_like
        Coordinates to bin: a single 1-D array of length N, or a sequence
        of D such arrays (equivalently a (D, N) array).
    ys : array_like or list of array_like
        Values handed to `func`: a single array of length N, or a list of
        arrays each of length N.  In the latter case `func` receives the
        per-bin selections as separate positional arguments.
    bins : int, sequence of int, or sequence of array_like, optional
        The bin specification, in one of the following forms:
          * A sequence of arrays describing the bin edges along each
            dimension.
          * The number of bins for each dimension (n1, n2, ... = bins).
          * The number of bins for all dimensions (n1 = n2 = ... = bins).
        Integer counts and explicit edges may be mixed across dimensions.
        For a single 1-D `xs`, `bins` may also be one array of edges.
    func : callable, optional
        Takes one array per entry of `ys` (the values falling in a bin) and
        returns a scalar or an array.  It is called once per non-empty bin.
        Empty bins are filled with ``func(*[[]] * n_ys)``; if that call
        raises, they are filled with NaNs shaped like `func`'s output.
    nmin : int, optional
        Bins holding fewer than `nmin` points are filled as if empty
        (their `bin_count` is left unchanged).

    Returns
    -------
    BinStats
        stats : ndarray
            Statistic per bin, shape ``(n1, ..., nD) + output_shape``.
        bin_edges : list of ndarray
            The D arrays of (ni + 1) bin edges.
        bin_count : ndarray of int
            Number of points in each bin, shape ``(n1, ..., nD)``.

    Notes
    -----
    Bins are half-open ``[lo, hi)``.  Points below the first edge, at or
    above the last edge, or with a NaN coordinate are ignored.  For integer
    bin counts the last edge is placed strictly above the data maximum, so
    the maximum lands in the last bin.  Malformed input trips an
    ``assert`` (these checks are skipped under ``python -O``).

    See Also
    --------
    numpy.histogramdd, scipy.stats.binned_statistic_dd
    """
    assert hasattr(xs, '__len__') and len(xs) > 0
    if np.isscalar(xs[0]):
        # A single 1-D coordinate array: treat it (and its bin spec) as D = 1.
        xs = [np.asarray(xs)]
        bins = [bins]
    else:
        xs = [np.asarray(x) for x in xs]
        assert len(xs[0]) > 0
    # `D`: number of dimensions; `N`: number of points.
    D, N = len(xs), len(xs[0])
    for x in xs:
        assert len(x) == N and x.ndim == 1

    assert hasattr(ys, '__len__') and len(ys) > 0
    if np.isscalar(ys[0]):
        ys = [np.asarray(ys)]
    else:
        ys = [np.asarray(y) for y in ys]
    for y in ys:
        assert len(y) == N

    if np.isscalar(bins):
        bins = [bins] * D
    else:
        assert len(bins) == D

    edges = [_auto_edges(x, spec) if np.isscalar(spec) else np.asarray(spec)
             for x, spec in zip(xs, bins)]
    dims = tuple(len(edge) - 1 for edge in edges)

    null = _empty_value(func, ys)

    # Per-dimension bin index of every point; -1 marks "outside all bins".
    idx = np.empty((D, N), dtype='int')
    for i in range(D):
        ix = np.searchsorted(edges[i], xs[i], side='right') - 1
        ix[ix >= dims[i]] = -1
        idx[i] = ix
    # Flatten to one index into the raveled output.  mode='clip' only keeps
    # ravel_multi_index from rejecting the -1 markers; they are restored
    # right after so out-of-range points stay excluded.
    idx_ravel = np.ravel_multi_index(idx, dims, mode='clip')
    idx_ravel[(idx < 0).any(axis=0)] = -1

    res = np.empty(dims + null.shape, dtype=null.dtype)
    cnt = np.zeros(dims, dtype='int')
    # Flat views of the freshly allocated (contiguous) outputs.
    res_ravel = res.reshape((-1,) + null.shape)
    cnt_ravel = cnt.ravel()
    res_ravel[...] = null

    # Group points by bin with one stable sort instead of scanning all N
    # points once per bin; stability keeps each bin's values in input order.
    order = np.argsort(idx_ravel, kind='mergesort')
    sorted_idx = idx_ravel[order]
    first_valid = np.searchsorted(sorted_idx, 0, side='left')  # skip the -1s
    order, sorted_idx = order[first_valid:], sorted_idx[first_valid:]
    bin_ids, starts, counts = np.unique(sorted_idx, return_index=True,
                                        return_counts=True)
    for bin_id, start, count in zip(bin_ids, starts, counts):
        members = order[start:start + count]
        res_ravel[bin_id] = func(*[y[members] for y in ys])
        cnt_ravel[bin_id] = count

    if nmin is not None:
        res_ravel[cnt_ravel < nmin] = null

    return BinStats(res, edges, cnt)


if __name__ == '__main__':
    # Smoke-test the supported call patterns, then print one case next to
    # scipy.stats.binned_statistic_dd for comparison.  A fixed seed makes
    # the printed output reproducible between runs.
    from numpy.random import RandomState
    x = RandomState(0).rand(1000)
    b = np.linspace(0, 1, 11)

    # 1-D coordinates, scalar or multi-argument statistics.
    binstats(x, x, 10, np.mean)
    binstats(x, x, b, np.mean)
    binstats(x, x, b, np.mean, nmin=100)
    binstats(x, [x, x], 10, lambda x, y: np.mean(x + y))
    binstats(x, [x, x], 10, lambda x, y: [np.mean(x), np.std(y)])

    # 2-D coordinates with per-dimension bin counts or edges.
    binstats([x, x], x, (10, 10), np.mean)
    binstats([x, x], x, [b, b], np.mean)
    binstats([x, x], [x, x], 10, lambda x, y: [np.mean(x), np.std(y)])

    b1 = np.linspace(0, 1, 6)
    b2 = np.linspace(0, 1, 11)
    binstats([x, x], [x, x], [b1, b2], lambda x, y: [np.mean(x), np.std(y)])

    from scipy.stats import binned_statistic_dd
    print(binned_statistic_dd(x, x, 'std', bins=[b])[:2])
    print(binstats(x, x, bins=b, func=np.std)[:2])