[Relay] conv2d_transpose doesn't handle output_padding correctly

The output_padding parameter only pads the TOPI output with zeros.

I am trying to use conv2d_transpose to compute the gradient of conv2d. Because output_padding only pads with zeros, I sometimes get output that has the right shape but wrong values (zeros) along the edges. This happens when the forward convolution uses strides.

I also looked at the registered gradient code for nn.conv2d and it has the same issue for the input gradient.

Maybe I am doing something wrong here, but I think padding with zeros is the wrong thing to do. None of the other frameworks I checked do this.

Can you explain what the desired behavior is?
Padding with zeros should be correct. For gradients, output_padding is mainly used for non-divisible strides, where several input sizes produce the same conv output size. In that case, we pad the conv gradient with zeros.
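To illustrate why output_padding is needed for non-divisible strides, here is a minimal sketch of the shape arithmetic along the height axis of the example below. It assumes PyTorch's documented output-size formulas for convolution and transposed convolution; the helper names are hypothetical.

```python
def conv2d_out(size, k, stride, pad, dilation=1):
    # Forward convolution output length along one spatial axis.
    return (size + 2 * pad - dilation * (k - 1) - 1) // stride + 1


def conv2d_transpose_out(size, k, stride, pad, dilation=1, output_padding=0):
    # Transposed convolution output length (assumed: PyTorch's documented formula).
    return (size - 1) * stride - 2 * pad + dilation * (k - 1) + output_padding + 1


# Height axis of the example: kernel 2, stride 2, padding 1.
# Input heights 2 and 3 both produce a forward output height of 2 ...
print([conv2d_out(h, 2, 2, 1) for h in (2, 3)])  # [2, 2]
# ... so the transpose needs output_padding to choose between them.
print([conv2d_transpose_out(2, 2, 2, 1, output_padding=op) for op in (0, 1)])  # [2, 3]
```

With output_padding=1 the transpose recovers height 3, matching expected_out_shp in the script below. The formula only fixes the shape; it says nothing about whether the extra row should be zeros.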

Sorry for the late reply.

I wrote a script that runs conv2d_transpose in PyTorch and in Relay so the results can be compared:

import numpy as np

strides = (2, 1)
padding = (1, 1)
dilation = (1, 1)
groups = 1
output_padding = (1, 0)
kernel_size = (2, 2)
in_channels = 1

grad_shp = (1, 1, 2, 3)
w_shp = (1, 1, 2, 2)

expected_out_shp = (1, 1, 3, 2)

grad_val = np.random.rand(*grad_shp)
w_val = np.random.rand(*w_shp)

## TORCH ##

import torch

grad_t = torch.tensor(grad_val)
w_t = torch.tensor(w_val)
torch_out = torch.nn.functional.conv_transpose2d(
    grad_t, w_t,
    stride=strides,
    padding=padding,
    dilation=dilation,
    groups=groups,
    output_padding=output_padding).numpy()

print("Torch result:\n", torch_out)

## RELAY ##

import tvm
from tvm import relay

grad_c = relay.const(grad_val)
w_c = relay.const(w_val)
out_node = relay.nn.conv2d_transpose(
    grad_c, w_c,
    strides=strides,
    padding=padding,
    dilation=dilation,
    groups=groups,
    output_padding=output_padding,
    kernel_size=kernel_size,
    channels=in_channels)

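ctx = tvm.ndarray.context('cpu', 0)
mod = relay.Module({})
# Avoid shadowing Python's built-in exec()
executor = relay.create_executor(mod=mod, ctx=ctx, target='llvm')

# Convert the TVM NDArray to NumPy so it prints like the Torch result
relay_out = executor.evaluate(out_node).asnumpy()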

print("Relay result:\n", relay_out)

Running this script produces the following output:

Torch result:
 [[[[0.10655919 0.26356774]
   [0.3575391  0.16530708]
   [0.98032029 0.42839234]]]]

... Some tvm debug output ...

Relay result:
 [[[[0.10655919 0.26356772]
   [0.35753912 0.16530707]
   [0.         0.        ]]]]

I would expect the last row of the Relay result to match the Torch output instead of being all zeros. This zero padding is what I am referring to.
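To make the expected semantics concrete, here is a minimal NumPy reference sketch (the function name is hypothetical; it handles groups=1 only and assumes PyTorch's weight layout (C_in, C_out, kH, kW) and output-size formula). It scatters every input element through the kernel, so rows added by output_padding still receive contributions from the last input row. Zero-padding the output_padding=0 result therefore only agrees on the first rows.

```python
import numpy as np


def conv2d_transpose_ref(x, w, stride, padding, dilation, output_padding):
    """Scatter-based transposed convolution reference (groups=1 only).

    x: (N, C_in, H, W); w: (C_in, C_out, kH, kW), assumed PyTorch layout.
    """
    n, c_in, h, wd = x.shape
    _, c_out, kh, kw = w.shape
    (sh, sw), (ph, pw) = stride, padding
    (dh, dw), (oph, opw) = dilation, output_padding
    # output_padding only enlarges the output; it does not change the scatter.
    h_out = (h - 1) * sh - 2 * ph + dh * (kh - 1) + oph + 1
    w_out = (wd - 1) * sw - 2 * pw + dw * (kw - 1) + opw + 1
    out = np.zeros((n, c_out, h_out, w_out), dtype=x.dtype)
    for i in range(h):
        for j in range(wd):
            for p in range(kh):
                for q in range(kw):
                    oi = i * sh + p * dh - ph
                    oj = j * sw + q * dw - pw
                    # Positions outside the output window are dropped.
                    if 0 <= oi < h_out and 0 <= oj < w_out:
                        # (N, C_in) @ (C_in, C_out) -> (N, C_out)
                        out[:, :, oi, oj] += x[:, :, i, j] @ w[:, :, p, q]
    return out


rng = np.random.default_rng(0)
x = rng.random((1, 1, 2, 3))
w = rng.random((1, 1, 2, 2))
args = dict(stride=(2, 1), padding=(1, 1), dilation=(1, 1))

full = conv2d_transpose_ref(x, w, output_padding=(1, 0), **args)
no_pad = conv2d_transpose_ref(x, w, output_padding=(0, 0), **args)
# What zero-padding the output would give instead.
zero_padded = np.pad(no_pad, ((0, 0), (0, 0), (0, 1), (0, 0)))

print(full.shape)  # (1, 1, 3, 2)
print(full[0, 0, 2])  # nonzero: fed by x[..., 1, :] and w[..., 1, :]
print(zero_padded[0, 0, 2])  # [0. 0.]
```

In this sketch, row 2 of the output receives contributions from input row 1 through kernel row 1 (since 1 * 2 + 1 - 1 = 2). Without output_padding that position falls outside the output and is discarded, which is why padding afterwards cannot recover it.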

Thanks for the report. I’ll look into it

I finally had some time, so I made a PR to fix the problem: https://github.com/apache/incubator-tvm/pull/4318

Could you review it (or tag people that should review it)?

Isn’t this an issue we need to tag, rather than just a pull request?

I don’t know what you mean. I made a fix for that issue, so it’s a PR.