Hi Team,
Here is a short version of one of the examples I tried:
```python
import tvm
from tvm import te

n = te.var('n')
m = te.var('m')
C = te.placeholder((n,), name='C')
A = te.placeholder((n, m), dtype='float32', name='A')
k = te.reduce_axis((0, m), name='k')
comp = lambda a, b: a + b
init = lambda dtype: tvm.tir.const(0, dtype=dtype)  # this also accepts the functions min and max
product = te.comm_reducer(comp, init)
B = te.compute((n,), lambda i: product(A[i, k], axis=k), name='B')
s = te.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))
```
This gives the following output:
```
produce B {
  for (i, 0, n) {
    B[(i*stride)] = 0f
    for (k, 0, m) {
      B[(i*stride)] = (B[(i*stride)] + A[((i*stride) + (k*stride))])
    }
  }
}
```
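For reference, the lowered IR above computes the same thing as the plain-Python loop nest below. This is only a sketch: `lowered_B` is a hypothetical name, nested lists stand in for the TVM buffers, and strides are ignored.

```python
def lowered_B(A):
    """Plain-Python model of the lowered IR: B[i] = 0, then B[i] += A[i][k] over k."""
    n = len(A)
    B = [0.0] * n
    for i in range(n):
        B[i] = 0.0  # the identity element returned by init(dtype), i.e. 0f
        for k in range(len(A[i])):
            B[i] = B[i] + A[i][k]  # comp(a, b) = a + b
    return B


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(lowered_B(A))  # [6.0, 15.0]
```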
Instead of 0f in the example above, I am trying to initialize B[i] with C[i] by changing the init statement to:
```python
init = lambda i: (C[i], 'float32')
```
Part of the debug output:
```
File "…/…/check/reduce.py", line 64, in B = te.compute((n,), lambda i: product(A[i, k],axis=k), name='B')
File "/local/mnt/workspace/gpk/nnv3/tvm/python/tvm/tir/op.py", line 937, in reducer return _make_reduce(expr, axis, where)
File "/local/mnt/workspace/gpk/nnv3/tvm/python/tvm/tir/op.py", line 924, in _make_reduce id_elem = convert(id_elem)
File "/local/mnt/workspace/gpk/nnv3/tvm/python/tvm/runtime/object_generic.py", line 98, in convert return convert_to_object(value)
```
It then fails with the following error:
```
[bt] (1) /local/mnt/workspace/gpk/nnv3/tvm/install/lib/libtvm.so(tvm::tir::CallNode::make(tvm::runtime::DataType, std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator >, tvm::Array<tvm::PrimExpr, void>, tvm::tir::CallNode::CallType, tvm::tir::FunctionRef, int)+0x2f4) [0x7fd8c427a394]
[bt] (0) /local/mnt/workspace/gpk/nnv3/tvm/install/lib/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x3f) [0x7fd8c3d53a7f]
File "/local/mnt/workspace/gpk/nnv3/tvm/src/tir/ir/expr.cc", line 258
TVMError: Check failed: args[i].dtype().is_int():
```
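My current reading of the traceback (an assumption, based on the `_make_reduce` frame): the identity function passed to `te.comm_reducer` seems to be called with the dtype of the reduced expression, not with the output index. In the modified `init`, the parameter `i` would then be bound to `'float32'`, so `C[i]` becomes `C['float32']`, which would explain the `is_int()` check failure on the index. The pure-Python model below illustrates this; `make_reduce_model` and `RecordingTensor` are hypothetical stand-ins, not TVM APIs.

```python
def make_reduce_model(fidentity, dtype="float32"):
    # Hypothetical, simplified model of how comm_reducer uses its identity
    # function: it is called once with the dtype string of the reduced expression.
    return fidentity(dtype)


class RecordingTensor:
    """Stand-in for placeholder C that records what it is indexed with."""

    def __init__(self):
        self.indices = []

    def __getitem__(self, idx):
        self.indices.append(idx)
        return ("C", idx)


C = RecordingTensor()

# The original identity works because its only parameter really is the dtype.
zero = make_reduce_model(lambda dtype: 0.0)

# The modified identity: its parameter "i" receives "float32", so C is
# indexed with a string rather than an integer loop variable.
bad_init = lambda i: (C[i], "float32")
make_reduce_model(bad_init)
print(C.indices)  # ['float32']
```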
Also, I suspect I am missing some information that needs to be passed to the compute statement because of the init change.
I also tried `init = lambda i,dtype (C[i],dtype=dtype)`, which gives an invalid syntax error.
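On the syntax error: the lambda above is missing the `:` after its parameters. Under the same hypothetical one-argument model of how the identity function is called (`make_reduce_model` is not a TVM API), adding the colon alone would likely not be enough, because a two-parameter identity would then fail with a `TypeError` when only the dtype is passed in.

```python
def make_reduce_model(fidentity, dtype="float32"):
    # Hypothetical model: the identity function receives only the dtype.
    return fidentity(dtype)


# The attempt as written does not even parse: the ':' after "dtype" is missing.
try:
    compile("lambda i,dtype (C[i],dtype=dtype)", "<init>", "eval")
    parses = True
except SyntaxError:
    parses = False

C = [10.0, 20.0]  # plain list standing in for placeholder C
init = lambda i, dtype: (C[i], dtype)  # colon added; now it parses

try:
    make_reduce_model(init)
    error = None
except TypeError as exc:  # missing 1 required positional argument: 'dtype'
    error = exc

print(parses, type(error).__name__)  # False TypeError
```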
I also tried initializing with just a variable 'p'.
- Can comm_reducer only handle constant values as initial values?
- If I change the convert() call inside _make_reduce in python/tvm/tir/op.py, will that allow an arbitrary array, instead of a constant value, to be passed to comm_reducer?
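One workaround I am considering (a sketch, not verified in TVM): keep the constant identity in `comm_reducer` and fold `C[i]` in afterwards with a second, elementwise `te.compute`. For an associative, commutative combiner with identity e, seeding row i with C[i] gives the same result as combining C[i] with the identity-seeded reduction (up to floating-point rounding for `+`). The plain-Python check below compares both for `+`, `min` and `max`; `seeded_reduce` and `two_stage` are hypothetical helpers. Separately, newer TVM releases appear to support an `init` value on reductions (the `init` field of `tvm.tir.Reduce`); whether a given build accepts it should be checked against its `comm_reducer` docstring.

```python
from functools import reduce


def seeded_reduce(A, C, combine):
    """What the modified init is meant to do: start row i's accumulator at C[i]."""
    out = []
    for i, row in enumerate(A):
        acc = C[i]
        for a in row:
            acc = combine(acc, a)
        out.append(acc)
    return out


def two_stage(A, C, combine, identity):
    """Workaround: reduce with a constant identity, then fold in C[i] separately."""
    partial = [reduce(combine, row, identity) for row in A]    # stage 1: constant init
    return [combine(C[i], partial[i]) for i in range(len(A))]  # stage 2: elementwise


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
C = [10.0, -1.0]
for combine, identity in [(lambda a, b: a + b, 0.0), (min, float("inf")), (max, float("-inf"))]:
    assert seeded_reduce(A, C, combine) == two_stage(A, C, combine, identity)
print(two_stage(A, C, lambda a, b: a + b, 0.0))  # [16.0, 14.0]
```

In TVM terms, stage 1 would be the existing `B` with `0f` as the identity, and stage 2 a separate compute such as `C[i] + B[i]`, at the cost of one extra elementwise pass.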
Please let me know your suggestions on this.
Regards, Gayatri P K