Conv2d_transpose kernel 2x2, strides (2,2) fails for CUDA - Cannot prove

I found that the conv2d_transpose op fails to compile for CUDA when the kernel size is 2x2 and the strides are (2,2). Errors:

tvm/src/pass/loop_partition.cc:544: Cannot prove: ((((floordiv((((40 - (floordiv((dh + 1), 2)*8))*(5 - floordiv((dw + 1), 2))) + 63), 64) - 1) - (((40 - (floordiv((dh + 1), 2)*8))*(5 - floordiv((dw + 1), 2))) - select((-63 <= ((40 - (floordiv((dh + 1), 2)*8))*(5 - floordiv((dw + 1), 2)))), (floordiv((((40 - (floordiv((dh + 1), 2)*8))*(5 - floordiv((dw + 1), 2))) + 63), 64)*63), 0))) + 1) >= 0), when generating the post doubt loop

  File "/root/workplace/tvm/src/pass/split_host_device.cc", line 135
TVMError: Check failed: !use_count_.count(v): variable dh has been used before definition!
During handling of the above exception, another exception occurred:

To reproduce the error

import tvm
from tvm import relay
import tensorflow as tf
input_tensor = "input_1"
# NHWC
input_shape=(1,16,16,8)
x = tf.compat.v1.placeholder(tf.float32, shape=input_shape, name=input_tensor)
# HWOI
w2 = tf.compat.v1.placeholder(tf.float32, shape=(2,2,3,8))
out_shape = tf.compat.v1.placeholder(tf.int32, shape=(4))  # output shape supplied at runtime
deconv = tf.compat.v1.nn.conv2d_transpose(x, w2, out_shape, (1,2,2,1), padding='VALID')

sess = tf.compat.v1.Session()
graph_def = sess.graph_def

mod, params = relay.frontend.from_tensorflow(graph_def, layout='NCHW', shape={input_tensor: input_shape})

target = tvm.target.cuda()
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
set_cuda_target_arch('sm_70')

with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target, params=params)

I tried other targets (llvm and arm_cpu) and they work fine. Only cuda fails.
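For reference, a minimal sketch of the passing CPU builds, run in the same session as the script above so that mod and params are already defined (arm_cpu shown as a commented alternative):

# Same model, CPU targets: these compile without error.
target = "llvm"
# target = tvm.target.arm_cpu()  # arm_cpu likewise builds fine
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target, params=params)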

Related to PR 4243

@tqchen @yongwww @kevinthesun @yzhliu @ZihengJiang @kimishpatel @umangyadav @optima2005 @Huyuwei

I can reproduce this issue. Many thanks for finding this! May I suggest you raise an issue on GitHub to track it? Or, if you prefer, I can create an issue quoting the test code above. Please advise, thanks!

Opened an issue: https://github.com/apache/incubator-tvm/issues/4470

The same issue was reported earlier in "Compile error for CUDA target".

I checked more kernel and strides combinations and found that the error happens whenever the kernel size equals the strides, e.g. (see the sketch after this list):

# kernel and strides when compilation for CUDA fails
2x2 and (2,2)
3x3 and (3,3)
4x4 and (4,4)
5x5 and (5,5)
2x3 and (2,3)
3x2 and (3,2)
1x2 and (1,2)
etc
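A minimal sketch of how these combinations can be checked, assuming the same TVM and TF versions as above. try_build is a hypothetical helper; it reuses the repro script verbatim, only parameterizing kernel and strides:

import tvm
from tvm import relay
import tensorflow as tf
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch

set_cuda_target_arch('sm_70')

def try_build(kernel, strides):
    # Build a one-op conv2d_transpose TF graph and try to compile it for CUDA.
    tf.compat.v1.reset_default_graph()
    input_shape = (1, 16, 16, 8)
    x = tf.compat.v1.placeholder(tf.float32, shape=input_shape, name="input_1")
    # HWOI: kernel HxW, 3 output channels, 8 input channels
    w = tf.compat.v1.placeholder(tf.float32, shape=kernel + (3, 8))
    out_shape = tf.compat.v1.placeholder(tf.int32, shape=(4))
    tf.compat.v1.nn.conv2d_transpose(x, w, out_shape,
                                     (1,) + strides + (1,), padding='VALID')
    graph_def = tf.compat.v1.Session().graph_def
    mod, params = relay.frontend.from_tensorflow(
        graph_def, layout='NCHW', shape={"input_1": input_shape})
    try:
        with relay.build_config(opt_level=3):
            relay.build(mod, tvm.target.cuda(), params=params)
        return "ok"
    except Exception as e:
        return "FAIL (%s)" % type(e).__name__

for kernel, strides in [((2, 2), (2, 2)), ((3, 3), (3, 3)), ((4, 4), (4, 4)),
                        ((2, 3), (2, 3)), ((1, 2), (1, 2)), ((2, 2), (1, 1))]:
    print(kernel, strides, try_build(kernel, strides))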

I also found that conv2d_transpose CUDA compilation fails if the number of output channels is 1, regardless of kernel and strides size. It fails even for kernel 1x1 and strides (1,1).

To reproduce

import tvm
from tvm import relay
import tensorflow as tf
input_tensor = "input_1"
# NHWC
input_shape=(1,16,16,8)
x = tf.compat.v1.placeholder(tf.float32, shape=input_shape, name=input_tensor)
# HWOI: single output channel
w2 = tf.compat.v1.placeholder(tf.float32, shape=(1,1,1,8))
out_shape = tf.compat.v1.placeholder(tf.int32, shape=(4))
deconv = tf.compat.v1.nn.conv2d_transpose(x, w2, out_shape, (1,1,1,1), padding='VALID')

sess = tf.compat.v1.Session()
graph_def = sess.graph_def

mod, params = relay.frontend.from_tensorflow(graph_def, layout='NCHW', shape={input_tensor: input_shape})

target = tvm.target.cuda()
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
set_cuda_target_arch('sm_70')

with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target, params=params)

Error:

  File "/root/workplace/tvm/src/pass/make_api.cc", line 204
TVMError: Not all Vars are passed in api_args:  'threadIdx.z'  does not appear in api_args
During handling of the above exception, another exception occurred:

TVMError: Not all Vars are passed in api_args:  'threadIdx.z'  does not appear in api_args
Error during compile function
-----------------------------
v0.0.4
fn (%p0: Tensor[(1, 8, 16, 16), float32], %p1: Tensor[(8, 1, 1, 1), float32], Primitive=1) -> Tensor[(1, 1, 16, 16), float32] {
  nn.conv2d_transpose(%p0, %p1, channels=1, kernel_size=[1, 1]) /* ty=Tensor[(1, 1, 16, 16), float32] */
}
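
Incidentally, the failing primitive can be reproduced directly in Relay, bypassing the TensorFlow frontend entirely. A minimal sketch, with shapes taken verbatim from the dumped function above (it should hit the same threadIdx.z check):

import tvm
from tvm import relay
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch

set_cuda_target_arch('sm_70')

# Rebuild the dumped primitive: 8 input channels, 1 output channel, 1x1 kernel.
data = relay.var("p0", shape=(1, 8, 16, 16), dtype="float32")
weight = relay.var("p1", shape=(8, 1, 1, 1), dtype="float32")
out = relay.nn.conv2d_transpose(data, weight, channels=1, kernel_size=(1, 1))
mod = relay.Module.from_expr(relay.Function([data, weight], out))

with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, tvm.target.cuda())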

A temporary workaround fix is PR 4472.

@apivovarov I get the same error but with kernel (4,4) and strides (2,2). Even after merging your workaround, the error is still there. Any idea?