我们通常希望在单个循环内计算具有相同维度的多个输出,或者,执行涉及argmax等多个值的缩减。
在这篇教程,我们将介绍在TVM中元组输入。
from __future__ import absolute_import, print_function
import tvm
import numpy as np
对于具有相同维度的运算符,如果我们希望它们在下一个策略程序中一起调度,我们可以将它们放在一起作为tvm.compute的输入。
n = tvm.var("n")
m = tvm.var("m")
A0 = tvm.placeholder((m,n),name='A0')
A1 = tvm.placeholder((m,n),name='A1')
B0, B1 = tvm.compute((m,n),lambda i,j:(A0[i,j]+2,A1[i,j]*3),name='B')
#生成IR中间表示代码
s = tvm.create_schedule(B0.op)
print(tvm.lower(s, [A0,A1,B0,B1],simple_mode=True))
有时,我们需要多个输入来表示一些Reduction算子,输入将协同工作,例如argmax。在Reduction过程中,argmax需要比较操作数的值,也需要去保存操作数的索引。这能使用comm_reducer来表示:
# xy是Reduction的操作数,他们是索引和值的元组
def fcombine(x,y):
lhs = tvm.expr.Select((x[1]>=y[1]),x[0],y[0])
rhs = tvm.expr.Select((x[1]>=y[1]),x[1],y[1])
return lhs,rhs
# 标识元素也需要是一个元组,所以‘fidentity’接受两种类型数据作为输入
def fidentity(t0,t1):#t0,t1为类型dtype
return tvm.const(-1,t0), tvm.min_value(t1)
argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')
#描述Reduction计算
m = tvm.var('m')
n = tvm.var('n')
idx = tvm.placeholder((m,n),name='idx',dtype='int32')
val = tvm.placeholder((m,n),name='val',dtype='int32')
k = tvm.reduce_axis((0,n), 'k')
T0, T1 = tvm.compute((m, ), lambda i: argmax(idx[i,k], val(i,k), axis=k),name='T')
#生成IR代码
s = tvm.create_schedule(T0.op)
print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))
教程介绍元组输入算子的使用: