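The function psolve solves linear systems a x = b on stacks of (possibly singular) matrices.
It behaves like numpy.linalg.solve when that succeeds, and otherwise falls back to a
pseudo-inverse (least-squares) solution, similar to numpy.linalg.lstsq.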
import numpy as np
def pinv(a, rcond=1e-15):
    """Compute the Moore-Penrose pseudo-inverse of a matrix or stack of matrices.

    Adapted from ``numpy.linalg.pinv``
    (cf. https://github.com/numpy/numpy/blob/master/numpy/linalg/linalg.py).

    Parameters
    ----------
    a : array_like, shape (..., M, N)
        Matrix or stack of matrices to pseudo-invert (real or complex).
    rcond : float or array_like, optional
        Relative cutoff for small singular values: singular values less than
        or equal to ``rcond * largest_singular_value`` are treated as zero.
        Broadcast against the stack dimensions of ``a``.

    Returns
    -------
    ndarray, shape (..., N, M)
        The pseudo-inverse of each matrix in ``a``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the SVD fails (e.g. ``a`` has fewer than 2 dimensions or the
        computation does not converge).
    """
    a = np.asanyarray(a)
    rcond = np.asarray(rcond)

    # An empty matrix has an empty pseudo-inverse. The SVD path below would
    # fail in np.amax (maximum over a zero-size axis has no identity).
    if a.ndim >= 2 and 0 in a.shape[-2:]:
        m, n = a.shape[-2:]
        return np.empty(a.shape[:-2] + (n, m), dtype=a.dtype)

    # Conjugate first so that the plain transposes below act as conjugate
    # transposes: svd(conj(a)) gives conj(U), s, conj(V^H), hence
    # vt.T @ diag(1/s) @ u.T == V S^+ U^H. Without this, complex input would
    # yield conj(pinv(a)).
    a = np.conjugate(a)
    u, s, vt = np.linalg.svd(a, full_matrices=False)

    # Discard small singular values, relative to the largest one per matrix.
    cutoff = rcond[..., np.newaxis] * np.amax(s, axis=-1, keepdims=True)
    large = s > cutoff
    # Invert (in place) only the kept singular values; zero out the rest.
    s = np.divide(1, s, where=large, out=s)
    s[~large] = 0

    res = np.matmul(np.swapaxes(vt, -1, -2),
                    np.multiply(s[..., np.newaxis], np.swapaxes(u, -1, -2)))
    return res
def psolve(a, b):
    """Solve ``a x = b``, tolerating singular (or non-square) matrices.

    First tries ``numpy.linalg.solve``. If that raises ``LinAlgError`` (for
    example because some matrix in the stack is singular, or the matrices are
    not square), the whole stack is instead solved with the pseudo-inverse,
    ``x = pinv(a) @ b``, which gives the minimum-norm least-squares solution.

    Parameters
    ----------
    a : array_like, shape (..., M, N)
        Coefficient matrix or stack of matrices (``M == N`` for the
        ``numpy.linalg.solve`` path).
    b : array_like
        Right-hand side. In the fallback path ``b`` is treated as a (stack of)
        vector(s) when ``b.ndim == a.ndim - 1``, otherwise as a (stack of)
        matrices.

    Returns
    -------
    x : ndarray
        Solution with the same vector/matrix layout as ``b``.
    """
    try:
        x = np.linalg.solve(a, b)
    except np.linalg.LinAlgError:
        # solve() accepts array_like inputs, so a and b may still be lists
        # here; convert before reading .ndim.
        a = np.asanyarray(a)
        b = np.asanyarray(b)
        # NOTE(review): this vector rule matches numpy < 2.0's solve(). Numpy
        # >= 2.0 treats b as a vector only when b.ndim == 1, so the two paths
        # may interpret a stacked 2-D b differently there — confirm intended.
        is_vector = a.ndim == b.ndim + 1
        if is_vector:
            b = b[..., np.newaxis]
        x = np.matmul(pinv(a), b)
        if is_vector:
            x = x[..., 0]
    return x