@@ -44,8 +44,7 @@ void add_slice_tensor_copy_node(
   vTensorPtr t_in = graph.get_tensor(in);
   vTensorPtr t_out = graph.get_tensor(out);
 
-  VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim));
-  VK_CHECK_COND(check_packed_dim_is(*t_out, WHCN::kChannelsDim));
+  VK_CHECK_COND(check_same_packed_dim(*t_in, *t_out));
 
   // Need normalize the dim
   int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
@@ -76,7 +75,13 @@ void add_slice_tensor_copy_node(
   start = normalize_idx(start, in_sizes[dim], 0);
   end = normalize_idx(end, in_sizes[dim], in_sizes[dim]);
 
-  if (dim_index == kChannel4D) {
+  const vkapi::SpecVarList spec_vars = {t_in->packed_dim()};
+
+  const auto packed_dim_idx =
+      static_cast<DimIndex>(DimIndex::DIM_LAST - t_in->packed_dim());
+
+  // if slice dim is the same as the packed dim, we can use the channel slice
+  if (dim_index == packed_dim_idx) {
     // slice by channel
     std::string kernel_name = "slice_channel";
     kernel_name.reserve(kShaderNameReserve);
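
For reference, the `DIM_LAST - packed_dim()` arithmetic in this hunk maps a WHCN packed-dim value onto the negative `DimIndex` space so it can be compared directly against `dim_index`. Below is a minimal standalone sketch of that mapping; the enum values are assumptions chosen to illustrate the arithmetic, not the backend's actual definitions.

```cpp
#include <cstdint>
#include <initializer_list>
#include <iostream>

// Assumed stand-ins for the WHCN packed-dim values and the negative DimIndex
// convention referenced in the hunk above; the real definitions live in the
// Vulkan backend headers and may differ.
enum WHCNDim : int32_t { kWidthDim = 0, kHeightDim = 1, kChannelsDim = 2 };
enum DimIndex : int32_t { kWidth4D = -1, kHeight4D = -2, kChannel4D = -3, kBatch4D = -4 };
constexpr int32_t DIM_LAST = -1;

int main() {
  // DIM_LAST - packed_dim folds a packed dim into the DimIndex space, so the
  // slice dim can be compared directly against the tensor's packed dim.
  for (int32_t packed : {kWidthDim, kHeightDim, kChannelsDim}) {
    const auto idx = static_cast<DimIndex>(DIM_LAST - packed);
    std::cout << "packed_dim " << packed << " -> DimIndex " << idx << '\n';
  }
  // Expected: 0 -> -1 (kWidth4D), 1 -> -2 (kHeight4D), 2 -> -3 (kChannel4D)
  return 0;
}
```
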
@@ -99,26 +104,31 @@ void add_slice_tensor_copy_node(
          {in, vkapi::MemoryAccessType::READ}},
         {t_out->sizes_ubo(),
          t_in->sizes_ubo(),
-         graph.create_params_buffer(params)}));
+         graph.create_params_buffer(params)},
+        spec_vars));
 
   } else {
     // GPU's coordinate is in x, y, z
     int64_t gpu_dim = -1;
-    int64_t stride = 1;
+    int64_t in_channel_stride = 1;
     if (dim_index == kWidth4D) {
       gpu_dim = 0; // width: x dimension in gpu
       VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
     } else if (dim_index == kHeight4D) {
       gpu_dim = 1; // height: y dimension
       VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
-    } else if (dim_index == kBatch4D) {
-      gpu_dim = 2; // batch: z dimension
-
-      // Due to channel packing, each batch value is span over stride planes
-      int64_t n_channels = dim_at(in_sizes, kChannel4D);
-      stride = utils::div_up_4(n_channels);
+    } else if (dim_index == kChannel4D) {
+      gpu_dim = 2; // channel: z dimension
+      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
+      in_channel_stride = dim_at(in_sizes, kChannel4D);
     } else {
-      VK_THROW("Unexpected ncwh_dim!");
+      gpu_dim = 3; // batch: w dimension
+
+      in_channel_stride = dim_at(in_sizes, kChannel4D);
+      if (packed_dim_idx == kChannel4D) {
+        // Due to channel packing, each batch value spans in_channel_stride planes
+        in_channel_stride = utils::div_up_4(in_channel_stride);
+      }
     }
 
     std::string kernel_name = "slice_batch_height_width";
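
The batch branch above computes `in_channel_stride`, the number of GPU z-planes each batch occupies. Under channel packing, four channels share one texel, so the stride becomes ceil(C / 4) rather than C. A small hedged sketch of that arithmetic, with `div_up_4` re-implemented as a hypothetical stand-in for `utils::div_up_4`:

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical re-implementation of utils::div_up_4 from the diff: how many
// 4-channel texel planes are needed to hold n channels.
int64_t div_up_4(int64_t n) {
  return (n + 3) / 4;
}

int main() {
  const int64_t channels = 10; // assumed example channel count

  // Width/height-packed input: each batch occupies one z-plane per channel.
  const int64_t stride_plain = channels;            // 10 planes per batch

  // Channels-packed input: four channels share a texel, so each batch spans
  // ceil(C / 4) planes, matching the div_up_4 adjustment in the batch branch.
  const int64_t stride_packed = div_up_4(channels); // ceil(10 / 4) = 3

  std::cout << stride_plain << " vs " << stride_packed << '\n';
  return 0;
}
```
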
@@ -137,7 +147,7 @@ void add_slice_tensor_copy_node(
         static_cast<int32_t>(gpu_dim),
         static_cast<int32_t>(start),
         static_cast<int32_t>(step),
-        static_cast<int32_t>(stride),
+        static_cast<int32_t>(in_channel_stride),
     };
 
     graph.execute_nodes().emplace_back(new DispatchNode(
@@ -147,7 +157,8 @@ void add_slice_tensor_copy_node(
         local_size,
         {{out, vkapi::MemoryAccessType::WRITE},
          {in, vkapi::MemoryAccessType::READ}},
-        {t_out->sizes_ubo(), graph.create_params_buffer(params)}));
+        {t_out->sizes_ubo(), graph.create_params_buffer(params)},
+        spec_vars));
   }
 }
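
The `out_sizes[dim] == 1 + (end - start - 1) / step` checks in the hunks above assume `start` and `end` have already been normalized. A self-contained sketch of that arithmetic follows; `normalize_idx` here is a hypothetical approximation of the backend helper (its real signature and fallback handling may differ).

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical approximation of normalize_idx: map a possibly-negative slice
// index into [0, size]. The real helper's behavior may differ; the third
// argument is kept only to mirror the call sites in the diff.
int64_t normalize_idx(int64_t idx, int64_t size, int64_t fallback) {
  (void)fallback; // unused in this sketch
  if (idx < 0) idx += size;
  if (idx < 0) idx = 0;
  if (idx > size) idx = size;
  return idx;
}

int main() {
  const int64_t size = 10;                            // assumed in_sizes[dim]
  const int64_t start = normalize_idx(-7, size, 0);   // -7 -> 3
  const int64_t end = normalize_idx(9, size, size);   //  9 -> 9
  const int64_t step = 2;

  // Elements kept: indices 3, 5, 7 -> 3 values, matching the diff's check
  // out_sizes[dim] == 1 + (end - start - 1) / step.
  const int64_t out_len = 1 + (end - start - 1) / step;
  std::cout << out_len << '\n'; // 3
  return 0;
}
```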