I’m afraid I can only share some code snippets here. The problem appears to be in the code generated for vectorize(). The related Python snippet is:
s[FS_1].compute_at(s[OL], xoicc)
h1, w1, i1, o1 = s[FS_1].op.axis
io = s[FS_1].fuse(i1, o1)
io, iox = s[FS_1].split(io, factor=num_thread_x * 4)
ioy, io = s[FS_1].split(io, nparts=num_thread_y)
iox, io4 = s[FS_1].split(iox, factor=4)
s[FS_1].reorder(h1, w1, io, ioy, iox, io4)
s[FS_1].bind(iox, thread_x)
s[FS_1].bind(ioy, thread_y)
s[FS_1].vectorize(io4)
With this schedule, the generated CUDA snippet is:
for (int ax2_ax3_fused_outer_inner = 0; ax2_ax3_fused_outer_inner < 4; ++ax2_ax3_fused_outer_inner) {
float4 _1;
int4 _2 = make_int4((((rc_outer * 16384) + (rc_inner_outer * 4096)) + (((int)threadIdx.y) * 1024)), (((rc_outer * 16384) + (rc_inner_outer * 4096)) + (((int)threadIdx.y) * 1024)), (((rc_outer * 16384) + (rc_inner_outer * 4096)) + (((int)threadIdx.y) * 1024)), (((rc_outer * 16384) + (rc_inner_outer * 4096)) + (((int)threadIdx.y) * 1024)));
int4 _3 = make_int4(256, 256, 256, 256);
int4 _4 = make_int4(0, 0, 0, 0);
uint4 _5;
_5.x = (_3.x>=_4.x);
_5.y = (_3.y>=_4.y);
_5.z = (_3.z>=_4.z);
_5.w = (_3.w>=_4.w);
int4 _6 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
int4 _7 = make_int4(256, 256, 256, 256);
int4 _8;
_8.x = (_6.x%_7.x);
_8.y = (_6.y%_7.y);
_8.z = (_6.z%_7.z);
_8.w = (_6.w%_7.w);
int4 _9 = make_int4(0, 0, 0, 0);
uint4 _10;
_10.x = (_8.x>=_9.x);
_10.y = (_8.y>=_9.y);
_10.z = (_8.z>=_9.z);
_10.w = (_8.w>=_9.w);
uint4 _11;
_11.x = (_5.x&&_10.x);
_11.y = (_5.y&&_10.y);
_11.z = (_5.z&&_10.z);
_11.w = (_5.w&&_10.w);
int4 _12 = make_int4(256, 256, 256, 256);
int4 _13 = make_int4(0, 0, 0, 0);
uint4 _14;
_14.x = (_12.x<_13.x);
_14.y = (_12.y<_13.y);
_14.z = (_12.z<_13.z);
_14.w = (_12.w<_13.w);
int4 _15 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
int4 _16 = make_int4(256, 256, 256, 256);
int4 _17;
_17.x = (_15.x%_16.x);
_17.y = (_15.y%_16.y);
_17.z = (_15.z%_16.z);
_17.w = (_15.w%_16.w);
int4 _18 = make_int4(0, 0, 0, 0);
uint4 _19;
_19.x = (_17.x<=_18.x);
_19.y = (_17.y<=_18.y);
_19.z = (_17.z<=_18.z);
_19.w = (_17.w<=_18.w);
uint4 _20;
_20.x = (_14.x&&_19.x);
_20.y = (_14.y&&_19.y);
_20.z = (_14.z&&_19.z);
_20.w = (_14.w&&_19.w);
uint4 _21;
_21.x = (_11.x||_20.x);
_21.y = (_11.y||_20.y);
_21.z = (_11.z||_20.z);
_21.w = (_11.w||_20.w);
int4 _22 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
int4 _23 = make_int4(256, 256, 256, 256);
int4 _24;
_24.x = (_22.x/_23.x);
_24.y = (_22.y/_23.y);
_24.z = (_22.z/_23.z);
_24.w = (_22.w/_23.w);
int4 _25 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
int4 _26 = make_int4(256, 256, 256, 256);
int4 _27;
_27.x = (_25.x/_26.x);
_27.y = (_25.y/_26.y);
_27.z = (_25.z/_26.z);
_27.w = (_25.w/_26.w);
int4 _28 = make_int4(1, 1, 1, 1);
int4 _29;
_29.x = (_27.x-_28.x);
_29.y = (_27.y-_28.y);
_29.z = (_27.z-_28.z);
_29.w = (_27.w-_28.w);
int4 _30 = _21 ? _24 : _29;
int4 _31 = make_int4(512, 512, 512, 512);
int4 _32;
_32.x = (_30.x*_31.x);
_32.y = (_30.y*_31.y);
_32.z = (_30.z*_31.z);
_32.w = (_30.w*_31.w);
int4 _33;
_33.x = (_2.x+_32.x);
_33.y = (_2.y+_32.y);
_33.z = (_2.z+_32.z);
_33.w = (_2.w+_32.w);
int4 _34 = make_int4(((((int)blockIdx.x) & 1) * 256), ((((int)blockIdx.x) & 1) * 256), ((((int)blockIdx.x) & 1) * 256), ((((int)blockIdx.x) & 1) * 256));
int4 _35;
_35.x = (_33.x+_34.x);
_35.y = (_33.y+_34.y);
_35.z = (_33.z+_34.z);
_35.w = (_33.w+_34.w);
int4 _36 = make_int4(256, 256, 256, 256);
int4 _37 = make_int4(0, 0, 0, 0);
uint4 _38;
_38.x = (_36.x>=_37.x);
_38.y = (_36.y>=_37.y);
_38.z = (_36.z>=_37.z);
_38.w = (_36.w>=_37.w);
int4 _39 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
int4 _40 = make_int4(256, 256, 256, 256);
int4 _41;
_41.x = (_39.x%_40.x);
_41.y = (_39.y%_40.y);
_41.z = (_39.z%_40.z);
_41.w = (_39.w%_40.w);
int4 _42 = make_int4(0, 0, 0, 0);
uint4 _43;
_43.x = (_41.x>=_42.x);
_43.y = (_41.y>=_42.y);
_43.z = (_41.z>=_42.z);
_43.w = (_41.w>=_42.w);
uint4 _44;
_44.x = (_38.x&&_43.x);
_44.y = (_38.y&&_43.y);
_44.z = (_38.z&&_43.z);
_44.w = (_38.w&&_43.w);
int4 _45 = make_int4(256, 256, 256, 256);
int4 _46 = make_int4(0, 0, 0, 0);
uint4 _47;
_47.x = (_45.x<_46.x);
_47.y = (_45.y<_46.y);
_47.z = (_45.z<_46.z);
_47.w = (_45.w<_46.w);
int4 _48 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
int4 _49 = make_int4(256, 256, 256, 256);
int4 _50;
_50.x = (_48.x%_49.x);
_50.y = (_48.y%_49.y);
_50.z = (_48.z%_49.z);
_50.w = (_48.w%_49.w);
int4 _51 = make_int4(0, 0, 0, 0);
uint4 _52;
_52.x = (_50.x<=_51.x);
_52.y = (_50.y<=_51.y);
_52.z = (_50.z<=_51.z);
_52.w = (_50.w<=_51.w);
uint4 _53;
_53.x = (_47.x&&_52.x);
_53.y = (_47.y&&_52.y);
_53.z = (_47.z&&_52.z);
_53.w = (_47.w&&_52.w);
uint4 _54;
_54.x = (_44.x||_53.x);
_54.y = (_44.y||_53.y);
_54.z = (_44.z||_53.z);
_54.w = (_44.w||_53.w);
int4 _55 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
int4 _56 = make_int4(256, 256, 256, 256);
int4 _57;
_57.x = (_55.x%_56.x);
_57.y = (_55.y%_56.y);
_57.z = (_55.z%_56.z);
_57.w = (_55.w%_56.w);
int4 _58 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
int4 _59 = make_int4(256, 256, 256, 256);
int4 _60;
_60.x = (_58.x%_59.x);
_60.y = (_58.y%_59.y);
_60.z = (_58.z%_59.z);
_60.w = (_58.w%_59.w);
int4 _61 = make_int4(256, 256, 256, 256);
int4 _62;
_62.x = (_60.x+_61.x);
_62.y = (_60.y+_61.y);
_62.z = (_60.z+_61.z);
_62.w = (_60.w+_61.w);
int4 _63 = _54 ? _57 : _62;
int4 _64;
_64.x = (_35.x+_63.x);
_64.y = (_35.y+_63.y);
_64.z = (_35.z+_63.z);
_64.w = (_35.w+_63.w);
_1.x = Conv2dFilter_1[_64.x];
_1.y = Conv2dFilter_1[_64.y];
_1.z = Conv2dFilter_1[_64.z];
_1.w = Conv2dFilter_1[_64.w];
((__shared__ float4*)(Conv2dFilter_1_shared + (((((int)threadIdx.y) * 512) + (ax2_ax3_fused_outer_inner * 128)) + (((int)threadIdx.x) * 4))))[0] = _1;
}
As mentioned above, the error occurs at these two lines:
int4 _30 = _21 ? _24 : _29;
and
int4 _63 = _54 ? _57 : _62;
I’m confused about how a uint4 variable such as _21 or _54 is supposed to be evaluated as true/false in these ternary expressions. Must all of its elements be true, or is it enough for one of them to be true?
If I comment out vectorize(), the corresponding CUDA snippet becomes:
for (int ax2_ax3_fused_outer_inner = 0; ax2_ax3_fused_outer_inner < 4; ++ax2_ax3_fused_outer_inner) {
for (int ax2_ax3_fused_inner_inner = 0; ax2_ax3_fused_inner_inner < 4; ++ax2_ax3_fused_inner_inner) {
Conv2dFilter_1_shared[((((((int)threadIdx.y) * 512) + (ax2_ax3_fused_outer_inner * 128)) + (((int)threadIdx.x) * 4)) + ax2_ax3_fused_inner_inner)] = Conv2dFilter_1[((((((rc_outer * 16384) + (rc_inner_outer * 4096)) + (((int)threadIdx.y) * 1024)) + (((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)) + ax2_ax3_fused_inner_inner) >> 8) * 512)) + ((((int)blockIdx.x) & 1) * 256)) + ((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)) + ax2_ax3_fused_inner_inner) & 255))];
}
}
which looks quite normal.