@lsy643 I pulled the latest TVM code and still get an error.
My error is in a different function (`_stridedSlice`), not in `_slice`:
def _stridedSlice():
    """Return a converter for the TensorFlow ``StridedSlice`` operator.

    The returned ``_impl`` callable emits a Relay ``strided_slice`` and then
    reshapes the result so that TensorFlow's new-axis / shrink-axis semantics
    are reproduced.
    """
    def _impl(inputs, attr, params, mod):
        """Strided Slice.

        Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
        Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
        tensorflow/core/util/strided_slice_op.cc#L147-L368

        Parameters
        ----------
        inputs : list
            ``inputs[0]`` is the data expression; ``inputs[1]``, ``inputs[2]``
            and ``inputs[3]`` are resolved to the begin, end and strides lists
            via ``_get_list_param``.
        attr : dict
            Operator attributes. Reads the integer bit masks ``begin_mask``,
            ``end_mask``, ``ellipsis_mask``, ``new_axis_mask`` and
            ``shrink_axis_mask`` (each defaulting to 0), and ``_input_shapes``,
            which maps ``inputs[0]`` to its shape.
        params
            Parameter table handed to ``_get_list_param``.
        mod
            Module handed to ``_infer_shape``.

        Returns
        -------
        The sliced expression. It is reshaped to the final output shape,
        except when that shape is empty, in which case the raw slice is
        returned.
        """
        begin = _get_list_param(params, inputs[1])
        end = _get_list_param(params, inputs[2])
        stride = _get_list_param(params, inputs[3])
        begin_mask = int(attr.get('begin_mask', 0))
        end_mask = int(attr.get('end_mask', 0))
        ellipsis_mask = int(attr.get('ellipsis_mask', 0))
        new_axis_mask = int(attr.get('new_axis_mask', 0))
        shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
        data_shape = attr['_input_shapes'][inputs[0]]
        data_dim = len(data_shape)
        stride_dim = len(stride)

        def _transform_mask(stride_dim, ellipsis_mask):
            """Handle mask inputs to create new begin, end, stride and output shape.

            Returns ``(m_begin, m_end, m_stride, fshape_indices)``.

            - ``m_begin``, ``m_end`` and ``m_stride`` each have one entry per
              data dimension.
            - ``fshape_indices`` drives the final reshape. A non-negative value
              is an index into the sliced shape, ``-1`` inserts an axis of
              size 1, and ``-2`` marks an axis dropped by ``shrink_axis_mask``.
            """
            m_begin = [0] * data_dim
            m_end = [0] * data_dim
            m_stride = [0] * data_dim
            fshape_indices = []

            # Count new axes after the ellipsis; they are taken into account
            # when computing how many dimensions the ellipsis expands to.
            ellipsis_seen = False
            new_axes_after_ellipsis = 0
            for i in range(stride_dim):
                mask = 1 << i
                if ellipsis_seen and (mask & new_axis_mask) != 0:
                    new_axes_after_ellipsis += 1
                if (mask & ellipsis_mask) != 0:
                    ellipsis_seen = True
            if not ellipsis_seen:
                # No explicit ellipsis: append an implicit one after the last
                # spec so any remaining data dimensions are taken whole below.
                ellipsis_mask |= (1 << stride_dim)
                stride_dim += 1

            final_index = 0
            for index in range(stride_dim):
                mask = 1 << index
                if mask & ellipsis_mask:
                    # Identify the end index for applying ellipsis_mask.
                    to_index = min(((data_dim - (stride_dim - index)) + 1
                                    + new_axes_after_ellipsis), data_dim)
                    for _ in range(final_index, to_index):
                        m_begin[final_index] = 0
                        m_end[final_index] = data_shape[final_index]
                        m_stride[final_index] = 1
                        fshape_indices.append(final_index)
                        final_index += 1
                elif mask & new_axis_mask:
                    fshape_indices.append(-1)
                else:
                    if final_index == len(m_begin):
                        break
                    if mask & begin_mask:
                        m_begin[final_index] = data_shape[final_index] \
                            if stride[index] < 0 else 0
                    elif begin[index]:
                        m_begin[final_index] = begin[index]
                    if mask & end_mask:
                        m_end[final_index] = 0 if stride[index] < 0 \
                            else data_shape[final_index]
                    elif end[index]:
                        m_end[final_index] = end[index]
                    m_stride[final_index] = stride[index]
                    if mask & shrink_axis_mask:
                        # A shrunk axis keeps exactly one element and is then
                        # dropped from the output (-2 below). First normalize
                        # a negative begin against the dimension size. Then
                        # derive end from the *normalized* begin. Using the raw
                        # begin would turn begin=-1 into the range
                        # [dim-1, 0), an empty slice that type inference
                        # rejects ("strided_slice get empty slice").
                        m_begin[final_index] = data_shape[final_index] + begin[index] \
                            if begin[index] < 0 else begin[index]
                        m_end[final_index] = m_begin[final_index] + 1
                        m_stride[final_index] = 1
                        fshape_indices.append(-2)
                    else:
                        fshape_indices.append(final_index)
                    final_index += 1
            return m_begin, m_end, m_stride, fshape_indices

        fshape_indices = None
        if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
            begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
        out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
        out_shape = _infer_shape(out, mod=mod)
        if not fshape_indices:
            fshape_indices = range(len(out_shape))

        # Create final output shape: insert size-1 new axes (-1), drop
        # shrunk axes (-2), and copy every other dimension from the slice.
        final_output = []
        for gather_index in fshape_indices:
            if gather_index == -1:
                final_output.append(1)
            elif gather_index == -2:
                continue
            else:
                final_output.append(out_shape[gather_index])

        if not final_output:
            return out
        return _op.reshape(out, newshape=tuple(final_output))
    return _impl
Error output:
an internal invariant was violated while typechecking your program
Check failed: begin_v < end_v (3 vs. 0) : strided_slice get empty slice at axis 0
......
File xxxxxx in _impl
out_shape = _infer_shape(out, mod=mod)
File xxxxx in infer_shape
out_type = infer_type(inputs, mod=mod)
File xxxxx in infer_type
new_mod = IRModule.from_expr(node)
Traceback:
[bt] (8) tvm::IRModuleNode::FromExpr(xxxx)
[bt] (7) tvm::IRModuleNode::Add(xxx)
[bt] (6) tvm::RunTypeCheck(xx)
[bt] (5) tvm::relay::inferType(xxx)
[bt] (4) tvm::relay::TypeInference::infer(xxx)
[bt] (3) tvm::relay::TypeSolver::solve(xxx)