syrte
10/23/2017 - 11:16 AM

psolve

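Function psolve solves linear systems a x = b for stacks of (possibly singular) matrices. It behaves like numpy.linalg.solve and falls back to a pseudo-inverse (least-squares) solution, similar to numpy.linalg.lstsq, when the system is singular.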

import numpy as np

def pinv(a, rcond=1e-15):
    """Compute the Moore-Penrose pseudo-inverse of a matrix or stack of matrices.

    Adapted from ``numpy.linalg.pinv``;
    cf. https://github.com/numpy/numpy/blob/master/numpy/linalg/linalg.py

    Parameters
    ----------
    a : array_like, shape (..., M, N)
        Matrix or stack of matrices to pseudo-invert (real or complex).
    rcond : float or array_like, optional
        Relative cutoff for small singular values: singular values less than
        or equal to ``rcond * largest_singular_value`` are treated as zero.
        Broadcast against the stack dimensions of ``a``.

    Returns
    -------
    ndarray, shape (..., N, M)
        Pseudo-inverse of each matrix in ``a``.

    Raises
    ------
    numpy.linalg.LinAlgError
        Propagated from ``numpy.linalg.svd`` (e.g. ``a`` has fewer than two
        dimensions, or the SVD does not converge).
    """
    a = np.asarray(a)
    rcond = np.asarray(rcond)

    # Matrices with a zero-length dimension have no singular values, so the
    # max-reduction below would fail; return an empty result of the
    # transposed matrix shape directly, as numpy does.
    if a.ndim >= 2 and 0 in a.shape[-2:]:
        m, n = a.shape[-2:]
        return np.empty(a.shape[:-2] + (n, m), dtype=a.dtype)

    u, s, vt = np.linalg.svd(a, full_matrices=False)

    # discard small singular values (relative to the largest one per matrix)
    cutoff = rcond[..., np.newaxis] * np.amax(s, axis=-1, keepdims=True)
    large = s > cutoff
    s = np.divide(1, s, where=large, out=s)
    s[~large] = 0

    # a = u @ diag(s) @ vt  =>  pinv(a) = vt^H @ diag(1/s) @ u^H.
    # Conjugate transposes are required for complex input; for real input
    # the conjugation is a no-op.
    vt_h = np.conj(np.swapaxes(vt, -1, -2))
    u_h = np.conj(np.swapaxes(u, -1, -2))
    return np.matmul(vt_h, np.multiply(s[..., np.newaxis], u_h))


def psolve(a, b):
    """Solve ``a x = b`` for a matrix or stack of (possibly singular) matrices.

    ``numpy.linalg.solve`` is tried first. If it raises ``LinAlgError``
    (e.g. a singular or non-square matrix anywhere in the stack), the whole
    stack is solved instead as ``x = pinv(a) @ b``, i.e. the minimum-norm
    least-squares solution.

    Parameters
    ----------
    a : array_like, shape (..., M, N)
        Coefficient matrix or stack of matrices.
    b : array_like
        Right-hand side. In the pseudo-inverse fallback, ``b`` is treated as
        a stack of vectors when ``b.ndim == a.ndim - 1`` and as a stack of
        matrices otherwise.

    Returns
    -------
    x : ndarray
        Solution of ``a x = b``.
    """
    # Convert up front so the fallback path can inspect .ndim on list inputs.
    a = np.asarray(a)
    b = np.asarray(b)
    try:
        x = np.linalg.solve(a, b)
    except np.linalg.LinAlgError:
        # Promote a stack of vectors to column matrices so matmul applies,
        # then strip the extra axis again afterwards.
        is_vector = a.ndim == b.ndim + 1
        rhs = b[..., np.newaxis] if is_vector else b
        x = np.matmul(pinv(a), rhs)
        if is_vector:
            x = x[..., 0]
    return x