wkentaro
11/5/2015 - 5:38 PM

cupy repeat: work in progress

# flake8: NOQA
# "flake8: NOQA" makes flake8 skip this entire file; it was originally added
# to suppress warning "H104  File contains nothing but comments".

# TODO(okuta): Implement tile

import numpy as np
import cupy


def repeat(a, repeats, axis=None):
    """Repeat elements of an array, mirroring ``numpy.repeat``.

    Args:
        a: Input array.
        repeats: Number of repetitions. An integer applies the same count
            to every element; a list or tuple gives one count per slice of
            ``a`` along ``axis``.
        axis: Axis along which to repeat. ``None`` flattens ``a`` first and
            is only allowed when ``repeats`` is an integer. Negative values
            count from the last axis.

    Returns:
        A new array containing the repeated elements.

    Raises:
        ValueError: If ``repeats`` is not an integer, list or tuple; if any
            count is negative; if a sequence is given with ``axis=None``; or
            if its length differs from ``a.shape[axis]``.
        IndexError: If ``axis`` is out of range for ``a``.
    """
    import numbers

    if isinstance(repeats, numbers.Integral):
        if repeats < 0:
            raise ValueError(
                "'repeats' should not be negative: {0}".format(repeats))
        if axis is None:
            if repeats == 0:
                # cupy.concatenate rejects an empty list of arrays.
                return cupy.zeros((0,), dtype=a.dtype)
            # Flatten to a column, copy it ``repeats`` times side by side and
            # flatten again so every element appears consecutively.
            b = a.reshape((-1, 1))
            b = cupy.concatenate([b] * repeats, axis=1)
            return b.reshape(-1)
    elif isinstance(repeats, (list, tuple)):
        # Check ``axis`` first: a.shape[None] would raise TypeError.
        if axis is None:
            raise ValueError(
                "'axis' should be specified if 'repeats' is list")
        if any(rep < 0 for rep in repeats):
            raise ValueError(
                "'repeats' should not be negative: {0}".format(repeats))
    else:
        raise ValueError(
            "'repeats' should be int or list: {0}".format(repeats))

    ndim = a.ndim
    if not -ndim <= axis < ndim:
        raise IndexError(
            "'axis' {0} is out of bounds for array of dimension {1}"
            .format(axis, ndim))
    if axis < 0:
        # Normalize so the rollaxis round trip below restores the position.
        axis += ndim

    if isinstance(repeats, numbers.Integral):
        repeats = [repeats] * a.shape[axis]
    elif len(repeats) != a.shape[axis]:
        raise ValueError(
            "'repeats' and 'axis' of 'a' should be same length: {} != {}"
            .format(a.shape[axis], len(repeats)))

    # Bring ``axis`` to the front so each a_r[i] is one slice to repeat.
    a_r = cupy.rollaxis(a, axis)
    b_r_shape = list(a_r.shape)
    b_r_shape[0] = sum(repeats)
    b_r = cupy.zeros(tuple(b_r_shape), dtype=a_r.dtype)
    # A running offset avoids recomputing sum(repeats[:i]) on every slice.
    offset = 0
    for i, rep in enumerate(repeats):
        for j in range(rep):
            b_r[offset + j] = a_r[i]
        offset += rep
    # Move the repeated axis back to its original position.
    return cupy.rollaxis(b_r, 0, axis + 1)


# Sanity checks: ``repeat`` on cupy arrays must agree with ``numpy.repeat``.
x = np.array([[1, 2], [3, 4]])
y = cupy.array([[1, 2], [3, 4]])
for reps, ax in ((2, None), (2, 1), ([1, 2], 0), ([1, 2], 1)):
    assert (np.repeat(x, reps, axis=ax).tolist() ==
            repeat(y, reps, axis=ax).tolist())

# Higher-rank inputs, repeated along axis 1.
for dims in ((2, 3, 4), (2, 3, 4, 5)):
    count = int(np.prod(dims))
    x = np.arange(count).reshape(dims)
    y = cupy.arange(count).reshape(dims)
    assert (np.repeat(x, 2, axis=1).tolist() ==
            repeat(y, 2, axis=1).tolist())