wkentaro
11/5/2015 - 5:38 PM

## CuPy `repeat` (work in progress)

Work-in-progress implementation of `repeat` for CuPy, checked against `numpy.repeat`.

``````python
# flake8: NOQA
# "flake8: NOQA" to suppress warning "H104  File contains nothing but comments"

# TODO(okuta): Implement tile

import numpy as np
import cupy

def repeat(a, repeats, axis=None):
    """Repeat elements of an array (CuPy counterpart of ``numpy.repeat``).

    Args:
        a (cupy.ndarray): Input array.
        repeats (int or list/tuple of ints): Number of repetitions. An int
            repeats every element the same number of times. A sequence
            gives one count per index along ``axis``; a length-1 sequence
            is broadcast to every index, as in NumPy.
        axis (int or None): Axis along which to repeat. Negative values
            count from the last axis. If ``None``, ``a`` is flattened
            first; this is only supported when ``repeats`` is an int.

    Returns:
        cupy.ndarray: The repeated array. Its length along ``axis`` (or
        its flat length when ``axis`` is None) is the sum of the counts.

    Raises:
        ValueError: If ``repeats`` is not an int/list/tuple, if any count
            is negative, if ``axis`` is None with sequence ``repeats``, or
            if the sequence length does not match ``a.shape[axis]``.
        IndexError: If ``axis`` is out of bounds for ``a``.
    """
    import numbers  # local import: accepts int, long and numpy integers

    if isinstance(repeats, numbers.Integral):
        if repeats < 0:
            raise ValueError(
                "'repeats' should not be negative: {0}".format(repeats))
        if axis is None:
            if repeats == 0:
                # concatenate() cannot be called with an empty list.
                return cupy.zeros((0,), dtype=a.dtype)
            # Lay every element out as its own row, copy that column
            # `repeats` times side by side, then flatten row-major so each
            # element's copies end up adjacent.
            column = a.reshape((-1, 1))
            return cupy.concatenate([column] * repeats, axis=1).reshape(-1)
    elif isinstance(repeats, (list, tuple)):
        # Checked before a.shape[axis] is evaluated below.
        if axis is None:
            raise ValueError(
                "'axis' should be specified if 'repeats' is list")
    else:
        raise ValueError(
            "'repeats' should be int or list: {0}".format(repeats))

    ndim = a.ndim
    if not -ndim <= axis < ndim:
        raise IndexError(
            "axis {0} is out of bounds for array of dimension {1}"
            .format(axis, ndim))
    if axis < 0:
        # Normalize so that rollaxis(..., axis + 1) below is correct.
        axis += ndim
    length = a.shape[axis]

    if isinstance(repeats, numbers.Integral):
        repeats = [repeats] * length
    else:
        repeats = list(repeats)
        if len(repeats) == 1:
            repeats = repeats * length
        if len(repeats) != length:
            raise ValueError(
                "'repeats' and 'axis' of 'a' should be same length: "
                "{0} != {1}".format(length, len(repeats)))
        if any(r < 0 for r in repeats):
            raise ValueError(
                "'repeats' should not be negative: {0}".format(repeats))

    # Move the target axis to the front so copies are plain row writes.
    src = cupy.rollaxis(a, axis)
    out_shape = list(src.shape)
    out_shape[0] = sum(repeats)
    out = cupy.zeros(tuple(out_shape), dtype=src.dtype)
    offset = 0  # running start row in `out`; avoids re-summing per index
    for i, rep in enumerate(repeats):
        for j in range(rep):
            out[offset + j] = src[i]
        offset += rep
    # Move the (now repeated) leading axis back to its original position.
    return cupy.rollaxis(out, 0, axis + 1)

if __name__ == '__main__':
    # Self-check of repeat() against numpy.repeat. It is guarded so that
    # importing this module does not allocate GPU arrays or run assertions
    # as a side effect.
    x = np.array([[1, 2], [3, 4]])
    y = cupy.array([[1, 2], [3, 4]])
    for reps, ax in [(2, None), (2, 1), ([1, 2], 0), ([1, 2], 1)]:
        expected = np.repeat(x, reps, axis=ax).tolist()
        actual = repeat(y, reps, axis=ax).tolist()
        assert expected == actual, (reps, ax, expected, actual)

    # Higher-rank inputs, repeating along a middle axis.
    for size, shape in [(2 * 3 * 4, (2, 3, 4)),
                        (2 * 3 * 4 * 5, (2, 3, 4, 5))]:
        x = np.arange(size).reshape(shape)
        y = cupy.arange(size).reshape(shape)
        expected = np.repeat(x, 2, axis=1).tolist()
        actual = repeat(y, 2, axis=1).tolist()
        assert expected == actual, shape
``````