syrte
4/4/2020 - 1:12 PM

Make transformation to a cKDTree object without rebuilding the tree.

Make transformation to a cKDTree object without rebuilding the tree.

    import numpy as np
    from scipy.spatial import cKDTree

    transform = {0: lambda x: -x}

    a = np.random.randn(1000, 2)
    b = a.copy()
    for k in transform:
        b[:, k] = transform[k](b[:, k])

    c = np.random.randn(1000, 2)
    d = c.copy()
    for k in transform:
        d[:, k] = transform[k](d[:, k])


    t0 = cKDTree(a)
    t1 = transform_ckdtree(t0, transform)
    t2 = cKDTree(b)

    p0 = np.hstack(t0.query_ball_point(c, 1))
    p1 = np.hstack(t1.query_ball_point(d, 1))
    p2 = np.hstack(t2.query_ball_point(d, 1))
    assert (p0 == p1).all()
    assert (p0 == p2).all()
import operator

import numpy as np
from scipy.spatial import cKDTree


# Record layout used to reinterpret the raw node buffer returned by
# ``cKDTree.__getstate__`` (scipy's C-level ``ckdtreenode`` struct). Field
# order and types must match the compiled struct exactly.
ckdtreenode_dtype = np.dtype([
    ('split_dim', np.intp),   # axis the node splits on (scipy uses -1 for leaves)
    ('children', np.intp),
    ('split', np.float64),    # split coordinate along ``split_dim``
    ('start_idx', np.intp),
    ('end_idx', np.intp),
    ('less', np.intp),        # raw child pointers; not edited here
    ('greater', np.intp),
    ('_less', np.intp),       # child node indices into the buffer
    ('_greater', np.intp)
])


def transform_ckdtree(tree, transform):
    """
    Make transformation to a cKDTree object without rebuilding the tree.

    Parameters
    ----------
    tree: cKDTree object
        The original tree. It is not modified.
    transform: dict {axis: transform_function, ...}
        Keys are integer axis indices; negative values count from the last
        axis. Each function must be a vectorized, monotonic mapping
        (increasing or decreasing) of the coordinates along its axis.
        e.g., transform={0: lambda x: 2-x, 1: lambda x: x-3} defines a
        reflection about x=1 for the first axis and a translation by -3
        along the second axis.

    Returns
    -------
    new_tree: cKDTree
        The transformed tree.

    Raises
    ------
    NotImplementedError
        If ``tree`` is periodic (built with ``boxsize``).
    IndexError
        If an axis key is out of range for the tree's dimensionality.
    TypeError
        If an axis key is not an integer.

    Notes
    -----
    This relies on the private pickle state of ``cKDTree``; the node buffer
    layout must match ``ckdtreenode_dtype``.

    Examples
    --------
    from scipy.spatial import cKDTree

    transform = {0: lambda x: -x}

    a = np.random.randn(1000, 2)
    b = a.copy()
    for k in transform:
        b[:, k] = transform[k](b[:, k])

    c = np.random.randn(1000, 2)
    d = c.copy()
    for k in transform:
        d[:, k] = transform[k](d[:, k])

    t0 = cKDTree(a)
    t1 = transform_ckdtree(t0, transform)
    t2 = cKDTree(b)

    p0 = np.hstack(t0.query_ball_point(c, 1))
    p1 = np.hstack(t1.query_ball_point(d, 1))
    p2 = np.hstack(t2.query_ball_point(d, 1))
    assert (p0 == p1).all()
    assert (p0 == p2).all()
    """
    state = tree.__getstate__()
    ctree_data, data, n, m, leafsize, maxes, mins, indices, boxsize, boxsize_data = state

    if boxsize is not None:
        raise NotImplementedError(
            "transform_ckdtree does not support periodic trees (boxsize)")

    # The data array is expected to be a copy already; guard against a live
    # view, since writing into it would corrupt the original tree.
    if np.may_share_memory(data, tree.data):
        data = data.copy()
    # maxes/mins are written below; copy them so the original tree is untouched.
    maxes = maxes.copy()
    mins = mins.copy()

    # Writable copy of the raw node bytes, viewed as an array of node records.
    buf = bytearray(ctree_data)
    ctree = np.frombuffer(buf, dtype=ckdtreenode_dtype)
    split_dim = ctree['split_dim']

    for key, func in transform.items():
        axis = operator.index(key)
        if not -m <= axis < m:
            raise IndexError(
                f"axis {key} is out of bounds for a {m}-dimensional tree")
        # Normalise negative axes: leaves carry split_dim == -1, so matching a
        # raw negative key would hit leaves instead of the real split nodes.
        axis %= m

        ix = split_dim == axis
        ctree['split'][ix] = func(ctree['split'][ix])
        data[:, axis] = func(data[:, axis])

        new_max, new_min = func(maxes[axis]), func(mins[axis])
        if new_max >= new_min:
            mins[axis], maxes[axis] = new_min, new_max
        else:
            # Decreasing map: points that were below each split now lie above
            # it, so the two children of every affected node swap roles.
            mins[axis], maxes[axis] = new_max, new_min
            old_less = ctree['_less'][ix]  # boolean indexing -> copy
            ctree['_less'][ix] = ctree['_greater'][ix]
            ctree['_greater'][ix] = old_less

    # Hand the node buffer back in the same container type __getstate__
    # produced, as a pickle round-trip would; __setstate__ may reject another.
    if isinstance(ctree_data, np.ndarray):
        new_ctree_data = np.frombuffer(
            buf, dtype=ctree_data.dtype).reshape(ctree_data.shape)
    else:
        new_ctree_data = bytes(buf)

    new_state = (new_ctree_data, data, n, m, leafsize, maxes, mins,
                 indices, boxsize, boxsize_data)

    new_tree = cKDTree.__new__(cKDTree)
    new_tree.__setstate__(new_state)
    return new_tree