FP16 Mixed-Precision Reduction with Result Cast

I have been attempting to write a single CUDA kernel reduction that takes FP16 inputs, casts them to FP32 for the reduction, and then stores the result back as FP16. Is this currently possible in a single kernel with Tensor Expressions? When I try to add a cast to the result, I get an error suggesting that the reduction has to be the last operation. I have also tried adding another compute line to the schedule, but I get the same error.

TVMError: Check failed: 0 == level_: Reductions are only allowed at the top level of compute. Please create another tensor for further composition.

Example:

import tvm

tgt_host="llvm"
tgt="cuda"

toks         = tvm.var("tokens")
hidden       = tvm.const(1024)
inputs       = tvm.placeholder((toks, hidden), name='inputs', dtype='float16')
y            = tvm.reduce_axis((0, hidden), "y")
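# The outer .astype('float16') wrapped around tvm.sum below is what triggers the TVMError quoted above.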
outputs      = tvm.compute((toks,), lambda x : tvm.sum(inputs[x][y].astype('float32'), axis=y).astype('float16'),  name='outputs')
sched        = tvm.create_schedule([outputs.op])

sched[outputs].bind(outputs.op.axis[0], tvm.thread_axis("blockIdx.x"))
sched[outputs].bind(outputs.op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))

As the error says (“Reductions are only allowed at the top level of compute.”), the cast to fp16 cannot be added inside the reduction compute.
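Concretely, the “create another tensor for further composition” part of the message points at the workaround: keep the reduction as the only operation of one compute op and do the cast in a second one. A minimal sketch using the same pre-0.6 tvm API as your snippet (sums is just an illustrative name for the intermediate fp32 tensor, and the schedule below is only one possible way to bind it):

import tvm

toks   = tvm.var("tokens")
hidden = 1024

inputs = tvm.placeholder((toks, hidden), name='inputs', dtype='float16')
y      = tvm.reduce_axis((0, hidden), "y")

# Stage 1: accumulate in fp32; the reduction is the top-level operation of this op.
sums = tvm.compute((toks,),
                   lambda x: tvm.sum(inputs[x, y].astype('float32'), axis=y),
                   name='sums')

# Stage 2: a separate elementwise compute casts the fp32 result back to fp16.
outputs = tvm.compute((toks,),
                      lambda x: sums[x].astype('float16'),
                      name='outputs')

sched = tvm.create_schedule([outputs.op])

# Cross-thread reduction for the sum: one block per row, threads over the hidden dim.
sched[sums].bind(sums.op.axis[0], tvm.thread_axis("blockIdx.x"))
sched[sums].bind(sums.op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))

# The cast is its own op with its own loop, so it is scheduled separately
# (it does not get fused into the reduction loop).
sched[outputs].bind(outputs.op.axis[0], tvm.thread_axis("blockIdx.x"))

fadd = tvm.build(sched, [inputs, outputs], target="cuda", target_host="llvm")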

Thanks for the response! I am new to TVM, so I am not sure I am aware of the best practices and limitations. If you want to combine more operations than just a reduction into a kernel, it sounds like TOPI might fuse operations. Does this limitation exist there as well, where a reduction can only be the last operation? If not, is this an enhancement that needs to be requested? This pattern comes up in normalizations and softmax.

In this case you can use multiple compute ops; for example, you can see the softmax implementation in topi.
Elemwise ops cannot be fused into the reduction body because they don’t share the same loop.
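For reference, this is roughly how the topi softmax is structured: every reduction is the top-level operation of its own compute op, and the elementwise composition lives in a separate op. A simplified 2-D sketch (illustrative names, not the actual topi code):

import tvm

m = tvm.var("m")
n = tvm.var("n")
x = tvm.placeholder((m, n), name="x")

# Row-wise max, its own compute op.
k1 = tvm.reduce_axis((0, n), name="k1")
max_elem = tvm.compute((m,), lambda i: tvm.max(x[i, k1], axis=k1), name="max_elem")

# Row-wise sum of exponentials, another compute op.
k2 = tvm.reduce_axis((0, n), name="k2")
expsum = tvm.compute((m,),
                     lambda i: tvm.sum(tvm.exp(x[i, k2] - max_elem[i]), axis=k2),
                     name="expsum")

# Final elementwise normalization, a third compute op.
softmax = tvm.compute((m, n),
                      lambda i, j: tvm.exp(x[i, j] - max_elem[i]) / expsum[i],
                      name="softmax")

Each of these ops gets its own loop nest when scheduled, which is why the exp and the sum currently cannot share a loop.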

@vinx13 will this be possible in the future? It’d be nice to convert the reduction axis to a regular iteration axis and move two computations under the same root. For example, in softmax, it’d be nice to put sum and exp in the same loop.

@tqchen might have more thoughts on this