# cython: boundscheck=False, wraparound=False
# distutils: extra_compile_args = -fopenmp
# distutils: extra_link_args = -fopenmp
from cython.parallel import prange
from libc.math cimport exp, sqrt
import numpy as np
# Circle constants held as C doubles (initialised once from NumPy at import).
# NOTE(review): `pi` is not referenced in the visible part of this file.
cdef double pi = np.pi
# tau = 2*pi; binorm and binorm_alt divide by it, which supplies the
# 1 / (2*pi) normalisation factor of the bivariate normal density.
cdef double tau = np.pi * 2
def binorm(double[:] a, double[:] b, double[:] x, double[:] y,
           double[:] xvar, double[:] yvar, double[:] xyvar,
           int nthread=1):
    """binorm(a, b, x, y, xvar, yvar, xyvar, nthread=1)
    Return the probability density of data points in given model.

    The model is an equal-weight mixture of bivariate normal components.
    Component ``i`` has mean ``(x[i], y[i])`` and covariance matrix
    ``[[xvar[i], xyvar[i]], [xyvar[i], yvar[i]]]``.  For every data point
    ``(a[j], b[j])`` the component densities are summed and divided by the
    number of components.

    a, b :
        data point coordinates (1-D double arrays of equal length)
    x, y :
        component means (1-D double arrays of equal length)
    xvar, yvar, xyvar :
        component variances and covariance (same length as ``x``)
    nthread :
        number of threads handed to ``prange`` for the loop over data points

    Returns a 1-D float64 array with one density value per data point.

    Raises AssertionError if the model arrays, or the data arrays, do not
    all share the same length.
    """
    cdef:
        # Py_ssize_t is the native index type of memoryviews; a C int would
        # limit the supported array length.
        Py_ssize_t m, n, i, j
        double aj, bj, sx, sy, sxy, det, dx, dy
        double p, psum
        double[:] out
    m, n = x.shape[0], a.shape[0]
    # Raised explicitly instead of via `assert`: bounds checking is disabled
    # for this module, so these checks must survive `python -O`, otherwise
    # mismatched lengths would read out of bounds.  AssertionError is kept so
    # callers see the same exception type as before.
    if not (y.shape[0] == m and xvar.shape[0] == m
            and yvar.shape[0] == m and xyvar.shape[0] == m):
        raise AssertionError(
            "x, y, xvar, yvar and xyvar must have the same length")
    if b.shape[0] != n:
        raise AssertionError("a and b must have the same length")
    out = np.empty(n, dtype='double')

    for j in prange(n, nogil=True, num_threads=nthread):
        aj, bj = a[j], b[j]
        psum = 0.
        for i in range(m):
            dx, dy = aj - x[i], bj - y[i]
            sx, sy, sxy = xvar[i], yvar[i], xyvar[i]
            # Determinant of the 2x2 covariance matrix.
            det = sx * sy - sxy * sxy
            # p = -0.5 * d^T C^-1 d, with C^-1 written out for the 2x2 case.
            p = (2 * dx * dy * sxy - dx * dx * sy - dy * dy * sx) / 2. / det
            # Deliberately not `psum += ...`: inside prange an in-place
            # operator would mark psum as a reduction variable instead of a
            # per-iteration private accumulator.
            psum = psum + exp(p) / sqrt(det)
        # 1 / tau is the Gaussian normalisation; 1 / m averages the components.
        out[j] = psum / tau / m
    # `.base` is the NumPy array allocated above that backs the memoryview.
    return out.base
def binorm_alt(double[:] x, double[:] y,
           double[:] xvar, double[:] yvar, double[:] xyvar,
           double[:] a, double[:] b, double[:] w,
           int nthread=1):
    """binorm_alt(x, y, xvar, yvar, xyvar, a, b, w, nthread=1)
    Return the probability density of data points in given model.

    The model is described by points (a, b, w).  Each data point ``i`` is a
    bivariate normal with mean ``(x[i], y[i])`` and covariance matrix
    ``[[xvar[i], xyvar[i]], [xyvar[i], yvar[i]]]``.  Its density is evaluated
    at every model point ``(a[j], b[j])`` and the results are combined as a
    weighted mean with weights ``w[j]`` (normalised by ``sum(w)``).

    a, b, w :
        model points and weights (1-D double arrays of equal length)
    x, y, xvar, yvar, xyvar :
        data points: means, variances and covariance (1-D double arrays of
        equal length)
    nthread :
        number of threads handed to ``prange`` for the loop over data points

    Returns a 1-D float64 array with one density value per data point.

    Raises AssertionError if the data arrays, or the model arrays, do not
    all share the same length.
    """
    cdef:
        # Py_ssize_t is the native index type of memoryviews; a C int would
        # limit the supported array length.
        Py_ssize_t m, n, i, j
        double xi, yi, sx, sy, sxy, det, dx, dy
        double p, psum, wsum
        double[:] out
    m, n = x.shape[0], a.shape[0]
    # Raised explicitly instead of via `assert`: bounds checking is disabled
    # for this module, so these checks must survive `python -O`, otherwise
    # mismatched lengths would read out of bounds.  AssertionError is kept so
    # callers see the same exception type as before.
    if not (y.shape[0] == m and xvar.shape[0] == m
            and yvar.shape[0] == m and xyvar.shape[0] == m):
        raise AssertionError(
            "x, y, xvar, yvar and xyvar must have the same length")
    if not (b.shape[0] == n and w.shape[0] == n):
        raise AssertionError("a, b and w must have the same length")
    out = np.empty(m, dtype='double')

    # Total weight, accumulated serially before the parallel section.
    wsum = 0
    for j in range(n):
        wsum += w[j]
    for i in prange(m, nogil=True, num_threads=nthread):
        xi, yi = x[i], y[i]
        sx, sy, sxy = xvar[i], yvar[i], xyvar[i]
        # Determinant of the 2x2 covariance matrix of data point i.
        det = sx * sy - sxy * sxy
        psum = 0.
        for j in range(n):
            dx, dy = a[j] - xi, b[j] - yi
            # p = -0.5 * d^T C^-1 d, with C^-1 written out for the 2x2 case.
            p = (2 * dx * dy * sxy - dx * dx * sy - dy * dy * sx) / 2. / det
            # Deliberately not `psum += ...`: inside prange an in-place
            # operator would mark psum as a reduction variable instead of a
            # per-iteration private accumulator.
            psum = psum + exp(p) * w[j]
        # 1 / (tau * sqrt(det)) is the Gaussian normalisation, applied once
        # per data point because det does not depend on j.
        out[i] = psum / tau / sqrt(det) / wsum
    # `.base` is the NumPy array allocated above that backs the memoryview.
    return out.base