Bitpack/bitserial_conv2d error with HLS backend

Hi, I’m trying out the bitserial operators in Relay and deploying on the aocl_sw_emu backend with llvm as the target host. My script is below (mostly adapted from this notebook: https://github.com/jwfromm/Riptide/blob/master/notebooks/Relay/BitpackTest.ipynb):

import tvm
import numpy as np
import topi
from tvm import relay
import topi.testing
from tvm.contrib import graph_runtime
from topi.util import get_const_tuple

batch = 1
in_height = in_width = 32
in_channel = 32
num_filter = 32
kernel = 3
stride = (1, 1)
padding = (1, 1)
activation_bits = 1
weight_bits = 1
unipolar = True

input_dtype = 'uint8'
out_dtype = 'int8'

def generate_quantized_np(shape, bits, out_dtype):
    # uniform random integers representable in `bits` bits, i.e. in [0, 2**bits)
    min_val = 0
    max_val = 1 << bits
    return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)

with tvm.target.create('llvm'):
    # NCHW variant (also tried):
    #A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_dtype, name='A')
    #W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_dtype, name='W')
    #QW = topi.nn.bitpack(W, weight_bits, pack_axis=1, bit_axis=0, pack_type='uint8')

    # NHWC data with HWIO kernel
    A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_dtype, name='A')
    W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_dtype, name='W')
    
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)

a_np = generate_quantized_np(a_shape, activation_bits, input_dtype)
w_np = generate_quantized_np(w_shape, weight_bits, input_dtype)

if unipolar:
    # map weight bits {0, 1} -> {-1, +1} for the unipolar reference
    w_ = np.copy(w_np).astype(out_dtype)
    for x in np.nditer(w_, op_flags=['readwrite']):
        x[...] = 1 if x == 1 else -1
    b_np = topi.testing.conv2d_nhwc_python(a_np.astype(out_dtype), w_, stride, padding)
else:
    b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding)

input_var = relay.var('input', shape=A.shape, dtype=A.dtype)
kernel_var = relay.var('kernel', shape=W.shape, dtype=W.dtype)
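# pack the kernel along the input-channel axis (HWIO axis 2) into uint8 words:
# if I understand bitpack correctly, (3, 3, 32, 32) should become (3, 3, 4, 32, 1),
# with the new bit axis of length weight_bits appended at axis 4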
q_kernel = relay.nn.bitpack(kernel_var, bits=1, pack_axis=2, bit_axis=4, pack_type=input_dtype)
q_out = relay.nn.bitserial_conv2d(
    input_var, q_kernel, channels=32, kernel_size=(3, 3), padding=(1, 1),
    data_layout='NHWC', kernel_layout='HWIO', pack_dtype='uint8', out_dtype='int8')

q_func = relay.Function([input_var, kernel_var], q_out)

target = 'aocl_sw_emu'
ctx = tvm.context(target, 0)
# target = 'llvm'
# ctx = tvm.cpu()

with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(q_func, target=target, target_host='llvm', params={'kernel': w_np})

module = graph_runtime.create(graph, lib, ctx)
module.set_input('input', a_np)
module.set_input(**params)
module.run()

output = module.get_output(0).asnumpy()
tvm.testing.assert_allclose(output, b_np, rtol=1e-5)

This works fine with the CPU/x86 backend, but when I change the target to hls/aocl_sw_emu it fails at runtime, seemingly complaining about the output dimensions:

** WARNING: [acls10mx_ref0] NOT using DMA to transfer 32768 bytes from host to device because of lack of alignment
**                 host ptr (0x2e23050) and/or dev offset (0x0) is not aligned to 64 bytes
Traceback (most recent call last):

  File "tut.py", line 73, in <module>
    module.run()

  File "/home/chungs31/repos/Riptide/tvm/python/tvm/contrib/graph_runtime.py", line 176, in run
    self._run()

  File "/home/chungs31/repos/Riptide/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/chungs31/repos/Riptide/tvm/build/libtvm.so(tvm::runtime::GraphRuntime::Run()+0x5e) [0x7fea69c51416]
  [bt] (7) /home/chungs31/repos/Riptide/tvm/build/libtvm.so(std::function<void ()>::operator()() const+0x32) [0x7fea69194692]
  [bt] (6) /home/chungs31/repos/Riptide/tvm/build/libtvm.so(+0x22e5bac) [0x7fea69c57bac]
  [bt] (5) /home/chungs31/repos/Riptide/tvm/build/libtvm.so(+0x22e31b9) [0x7fea69c551b9]
  [bt] (4) /home/chungs31/repos/Riptide/tvm/build/libtvm.so(tvm::runtime::PackedFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x3d) [0x7fea692ebeff]
  [bt] (3) /home/chungs31/repos/Riptide/tvm/build/libtvm.so(std::function<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x6d) [0x7fea692ec473]
  [bt] (2) /home/chungs31/repos/Riptide/tvm/build/libtvm.so(+0x227ecbc) [0x7fea69bf0cbc]
  [bt] (1) /home/chungs31/repos/Riptide/tvm/build/libtvm.so(+0x227d910) [0x7fea69bef910]
  [bt] (0) /home/chungs31/repos/Riptide/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x25) [0x7fea69190907]
  File "/home/chungs31/repos/Riptide/tvm/src/runtime/library_module.cc", line 91
TVMError: Check failed: ret == 0 (-1 vs. 0) : Assert fail: (1 == int32(arg2.shape[3])), Argument arg2.shape[3] has an unsatisfied constraint

Is there something wrong with my shapes that just doesn’t surface as an error on x86? If not, how would I go about debugging this? I suspect it’s something related to bitpacking, perhaps shape constraints not being updated post-bitpack…
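
For reference, a quick way to sanity-check the packed kernel shape on the topi side (a sketch mirroring the relay.nn.bitpack call above; the expected shape is my guess):

# sketch: check the kernel shape after bitpacking, using the topi op directly
QW = topi.nn.bitpack(W, weight_bits, pack_axis=2, bit_axis=4, pack_type='uint8')
print(get_const_tuple(QW.shape))  # I'd expect (3, 3, 4, 32, 1): 32 channels / 8 bits per uint8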

Would appreciate any pointers.