虽然TVM支持透明代码生成,但有时将手动编写的代码合并到管道中也很有帮助。例如,我们想去为部分卷积和使用cuDNN和定义其他阶段。
TVM原生支持黑盒函数调用。TVM支持兼容DLPack的所有张量函数。这意味着我们可以使用POD类型(pointer,int,float)或指向DLTensor的指针作为参数调用任何函数。
from __future__ import absolute_import, print_function
import tvm
import numpy as np
from tvm.contrib import cblas
在下面的示例中,我们使用tvm.extern添加外部数组函数调用。在外部调用中,我们声明输出张量的形状。在第二个参数中,我们提供输入列表。
用户需要提供描述如何计算结果的函数。计算函数获取输入的符号占位符(placeholder)列表,输出的符号占位符列表,并返回执行语句。
在这种情况下,我们只需调用一个注册的TVM函数,该函数调用CBLAS库。 TVM不控制外部数组函数的内部并将其视为黑盒子。我们可以进一步混合可策略的TVM调用,为结果添加bias项。
n = 1024
l = 128
m = 235
bias = tvm.var('bias',dtype=tvm.float32)
#定义输入
A = tvm.placeholder((n,l),name='A')
B = tvm.placeholder((l,m),name='B')
#矩阵乘法(调用外部张量函数)
C = tvm.extern((n,m),[A,B],lambda ins,outs:tvm.call_packed("tvm.contrib.cblas.matmul",ins[0],ins[1],outs[0],False,False),nams="C")
#偏置加
D = tvm.compute(C.shape, lambda i, j: C[i,j]+bias,name="D")
s=tvm.create_schedule(D.op)
ctx = tvm.cpu(0)
f = tvm.build(s, [A,B,D,bias], "llvm")
a = tvm.nd.array(np.random.uniform(size=(n,l)).astype(A.dtype),ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
tvm.testing.assert_allclose(d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy())+10, rtol=1e-5)
下面例子等同于前面的例子
from tvm.contrib import cblas
C = cblas.matmul(A,B)
D = tvm.compute(C.shape, lambda i,j: C[i,j]+bias,name="D")
s = tvm.create_schedule(D.op)
在TVM能够调用任何包函数(PackedFunc)。我们能用外部函数去调用python。
以下示例,将python函数注册到TVM运行时系统,并使用它来完成计算的一个阶段。这使得TVM更加灵活。例如,我们可以插入前端回调函数来检查中间结果或将自定义代码与TVM混合。
@tvm.register_func("tvm.contrib.my_tvm_addone")
def my_tvm_addon(x,y):
print("my_tvm_addone signatures: %s, %s" % (type(x), type(y)))
tvm.nd.array(x.asnumpy() + 1).copyto(y)
A = tvm.placeholder((n,), name='A')
#调用上面注册的函数来计算
B = tvm.extern(A.shape, [A], lambda ins, outs: tvm.call_packed(
"tvm.contrib.my_tvm_addone", ins[0], outs[0]), name="C")
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx)
f(a, b)
tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1, rtol=1e-5)