CUDA Injective schedule thread count is not always optimal

In CUDA’s default injective schedule, the thread count is set to tvm.target.current_target(allow_none=False).max_num_threads. This leads to a very large thread count per block (1024 on the M60 GPU I am testing on), which is not optimal for all ops.
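For context, the relevant part of the default CUDA injective schedule is essentially the following (a simplified sketch that paraphrases the topi/cuda code rather than quoting it; the function name is just for illustration):

import tvm

def _injective_schedule_sketch(sch, out):
    # Fuse all output axes into a single iteration axis.
    fused = sch[out].fuse(*sch[out].op.axis)
    # The per-block thread count is pinned to the target's maximum
    # (1024 on most recent NVIDIA GPUs, including the M60).
    num_thread = tvm.target.current_target(allow_none=False).max_num_threads
    # Split the fused axis and bind the pieces to block/thread indices.
    bx, tx = sch[out].split(fused, factor=num_thread)
    sch[out].bind(bx, tvm.thread_axis("blockIdx.x"))
    sch[out].bind(tx, tvm.thread_axis("threadIdx.x"))
    return sch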

For example, I am currently testing a softmax with input shape (10, 12, 512, 512), and found that 64 threads per block saved over 10ms.

On the flip side, I have found other ops that benefit from having this large thread count.

What do you think is the best way to resolve this? Can we auto-tune this value per op that uses the injective schedule? Should we allow ops to pass in their ideal thread count?

I feel like the right approach would be to parameterize the thread count and auto-tune it for ops that use the default injective schedule. Is there a good way to do this? Otherwise, we can just copy the injective schedule into softmax.py and do the auto-tuning there, but that’s less than ideal.
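Roughly, I am imagining something like this (a hypothetical sketch; the function name, knob name, and candidate values are only for illustration):

import tvm
from tvm import autotvm

def _tunable_injective_sketch(cfg, sch, out):
    fused = sch[out].fuse(*sch[out].op.axis)
    # Replace the hard-coded max_num_threads with a tunable knob.
    cfg.define_knob("num_thread", [32, 64, 128, 256, 512, 1024])
    bx, tx = sch[out].split(fused, factor=cfg["num_thread"].val)
    sch[out].bind(bx, tvm.thread_axis("blockIdx.x"))
    sch[out].bind(tx, tvm.thread_axis("threadIdx.x"))
    return sch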

@vinx13 @masahi do you have any thoughts?

There is no deep reason why we use 1024 threads per block on CUDA. I’m +1 for making this number tunable, as long as the default stays 1024 (to avoid a perf regression).

If the default injective schedule is slow for some particular ops, it would be good to have a tunable schedule.

@vinx13 I finally got around to trying to create a tunable schedule for CUDA softmax, but am seeing a strange error when I try to tune it: “cannot find workload in attribute of this schedule”.

My branch is here. I added the necessary lines to topi_integration and relay_integration, and updated nn/softmax.py to have the compute decorated with @tvm.target.generic_func. I then updated cuda/softmax.py to decorate the schedule with:

@autotvm.register_topi_schedule(generic.schedule_softmax, ["cuda", "gpu"], "direct")

I also added this to the top of the file:

autotvm.register_topi_compute(nn.softmax, ["cuda", "gpu"], "direct", nn.softmax.fdefault)
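For reference, the decorated schedule in cuda/softmax.py ends up looking roughly like this (an abridged sketch rather than the exact contents of my branch, assuming a 2D softmax output; the knob values are just what I am experimenting with):

import tvm
from tvm import autotvm
from .. import generic

@autotvm.register_topi_schedule(generic.schedule_softmax, ["cuda", "gpu"], "direct")
def schedule_softmax(cfg, outs):
    """Tunable CUDA schedule for softmax (abridged)."""
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
    softmax = outs[0]
    # The per-block thread count becomes a knob instead of a constant.
    cfg.define_knob("num_thread", [32, 64, 128, 256, 512, 1024])
    num_thread = cfg["num_thread"].val
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    # Bind the normalization stage; the max/sum reduction stages are
    # scheduled the same way in the full version.
    s[softmax].bind(softmax.op.axis[0], block_x)
    tx, _ = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
    s[softmax].bind(tx, thread_x)
    return s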

From some initial investigation, the functions that are registered in register_topi_compute seem to never be called (like config_dispatcher and template_call), leading to the workload attribute never being set. However, I can’t figure out why they’re not being called.

Do you have any ideas?

You are right that you need to call register_topi_compute, but it is weird that the function is never called. I will take a look at it.

@vinx13 did you have a chance to take a look?

This is an issue with compute registration. Although you declare softmax as a generic func, the compute for softmax in Relay is defined on the C++ side, so the Python generic func is never invoked. See my patch below.

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 54f13c68..9801d4a9 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -31,6 +31,11 @@ reg.register_schedule("nn.relu", schedule_injective)
 reg.register_pattern("nn.relu", OpPattern.ELEMWISE)
 
 # softmax
+@reg.register_compute("nn.softmax")
+def compute_softmax(attrs, inputs, out_type, target):
+    """Compute definition of softmax"""
+    return [topi.nn.softmax(inputs[0])]
+
 @reg.register_schedule("nn.softmax")
 def schedule_softmax(_, outputs, target):
     """Schedule definition of softmax"""
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index d3a71787..5e4f5cfa 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -348,7 +348,8 @@ RELAY_REGISTER_OP("nn.softmax")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(1)
-.add_type_rel("Identity", IdentityRel)
+.add_type_rel("Identity", IdentityRel);
+/*
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                          const Array<Tensor>& inputs,
                                          const Type& out_type,
@@ -357,7 +358,7 @@ RELAY_REGISTER_OP("nn.softmax")
   CHECK(param != nullptr);
   return Array<Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
 });
-
+*/
 
 // relay.nn.log_softmax
 TVM_REGISTER_API("relay.op.nn._make.log_softmax")

Ah, that makes sense 🙂 thanks for investigating!