# Work in progress: a cupy implementation of repeat (checked against numpy.repeat below).
# flake8: NOQA
# "flake8: NOQA" to suppress warning "H104 File contains nothing but comments"
# TODO(okuta): Implement tile
import numpy as np
import cupy
def repeat(a, repeats, axis=None):
    """Repeat each element of an array along an axis.

    Intended to match ``numpy.repeat``: element ``i`` along ``axis`` is
    emitted ``repeats[i]`` times, consecutively.

    Args:
        a (cupy.ndarray): Input array.
        repeats (int, list or tuple of int): Number of repetitions for each
            element. An int, or a one-element sequence, applies to every
            element; otherwise the sequence length must equal the length of
            ``a`` along ``axis``.
        axis (int or None): Axis along which to repeat; negative values count
            from the end. If ``None``, ``a`` is flattened first and a 1-D
            array is returned.

    Returns:
        cupy.ndarray: The repeated array. Its shape equals ``a.shape`` except
        along ``axis``, whose length is the total number of repetitions.

    Raises:
        ValueError: If ``repeats`` has an unsupported type, contains a
            negative count, or its length does not match ``a`` along ``axis``.
        IndexError: If ``axis`` is out of range for ``a``.
    """
    if not isinstance(repeats, (int, list, tuple)):
        raise ValueError(
            "'repeats' should be int or list: {0}".format(repeats))

    if axis is None:
        # numpy semantics: without an axis, repeat the flattened array.
        a = a.reshape(-1)
        axis = 0
    elif not -a.ndim <= axis < a.ndim:
        raise IndexError(
            'axis {} is out of bounds for array of dimension {}'.format(
                axis, a.ndim))
    else:
        # Normalize a negative axis; otherwise the final
        # rollaxis(..., axis + 1) would put the axis back in the wrong place.
        axis %= a.ndim

    # Bring the target axis to the front so every case below only has to
    # repeat along axis 0.
    a_r = cupy.rollaxis(a, axis)
    n = a_r.shape[0]
    rest = tuple(a_r.shape[1:])

    if isinstance(repeats, int):
        if repeats < 0:
            raise ValueError(
                "'repeats' should not be negative: {}".format(repeats))
        if repeats == 0:
            # concatenate() rejects an empty list of arrays.
            b_r = cupy.zeros((0,) + rest, dtype=a_r.dtype)
        else:
            # Lay `repeats` copies of each row side by side along a new
            # axis 1, then merge axes 0 and 1 so the copies of row i end up
            # consecutive.
            b_r = cupy.concatenate(
                [a_r.reshape((n, 1) + rest)] * repeats, axis=1)
            b_r = b_r.reshape((n * repeats,) + rest)
    else:
        repeats = list(repeats)
        if len(repeats) == 1:
            # A single count applies to every element, as in numpy.
            repeats = repeats * n
        if len(repeats) != n:
            raise ValueError(
                "'repeats' and 'axis' of 'a' should be same length: {} != {}"
                .format(n, len(repeats)))
        if any(count < 0 for count in repeats):
            raise ValueError(
                "'repeats' should not be negative: {}".format(repeats))
        b_r = cupy.zeros((sum(repeats),) + rest, dtype=a_r.dtype)
        # Copy row i into its run of output rows. The run start is tracked
        # incrementally rather than re-summing the prefix each iteration.
        start = 0
        for i, count in enumerate(repeats):
            b_r[start:start + count] = a_r[i]
            start += count

    # Move the repeated axis back to its original position.
    return cupy.rollaxis(b_r, 0, axis + 1)
if __name__ == '__main__':
    # Self-check against numpy.repeat. Guarded so that importing this module
    # does not allocate cupy arrays or run assertions as a side effect.
    _cases = [
        # (input as a numpy array, repeats, axis)
        (np.array([[1, 2], [3, 4]]), 2, None),
        (np.array([[1, 2], [3, 4]]), 2, 1),
        (np.array([[1, 2], [3, 4]]), [1, 2], 0),
        (np.array([[1, 2], [3, 4]]), [1, 2], 1),
        (np.arange(2 * 3 * 4).reshape((2, 3, 4)), 2, 1),
        (np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)), 2, 1),
    ]
    for _x, _reps, _axis in _cases:
        _expected = np.repeat(_x, _reps, axis=_axis).tolist()
        _actual = repeat(cupy.array(_x), _reps, axis=_axis).tolist()
        assert _expected == _actual, (_x.shape, _reps, _axis)