@@ -81,15 +81,14 @@ ValueRef prepack_biases(
81
81
ComputeGraph& graph,
82
82
const ValueRef vref,
83
83
const ValueRef weight,
84
- const bool transposed) {
84
+ const bool transposed,
85
+ const api::StorageType storage_type,
86
+ const api::GPUMemoryLayout memory_layout) {
85
87
auto sizes = graph.get_sizes_of (weight);
86
88
const int64_t out_channels = transposed ? sizes.at (1 ) : sizes.at (0 );
87
89
88
90
ValueRef v = graph.add_tensor (
89
- {out_channels},
90
- graph.get_dtype_of (weight),
91
- api::kTexture2D ,
92
- api::kWidthPacked );
91
+ {out_channels}, graph.get_dtype_of (weight), storage_type, memory_layout);
93
92
vTensorPtr t = graph.get_tensor (v);
94
93
95
94
api::ShaderInfo shader = get_nchw_to_image_shader (*t);
@@ -329,7 +328,13 @@ void add_conv2d_node(
329
328
330
329
ValueRef arg_in = prepack_if_tensor_ref (graph, in);
331
330
ValueRef arg_weight = prepack_weights (graph, weight, method);
332
- ValueRef arg_bias = prepack_biases (graph, bias, weight, transposed_val);
331
+ ValueRef arg_bias = prepack_biases (
332
+ graph,
333
+ bias,
334
+ weight,
335
+ transposed_val,
336
+ /* storage_type = */ api::kTexture2D ,
337
+ /* memory_layout = */ api::kWidthPacked );
333
338
334
339
vTensorPtr t_in = graph.get_tensor (arg_in);
335
340
vTensorPtr t_out = graph.get_tensor (out);
@@ -383,15 +388,16 @@ void add_conv1d_node(
383
388
const ValueRef dilation,
384
389
const ValueRef groups,
385
390
const ValueRef out) {
386
- if (graph.val_is_none (bias)) {
387
- VK_THROW (" conv1d: Null bias is not supported yet!" );
388
- }
389
-
390
391
ValueRef arg_in = prepack_if_tensor_ref (graph, in);
391
392
ValueRef arg_weight =
392
393
prepack_if_tensor_ref (graph, weight, graph.memory_layout_of (arg_in));
393
- ValueRef arg_bias =
394
- prepack_if_tensor_ref (graph, bias, graph.memory_layout_of (arg_in));
394
+ ValueRef arg_bias = prepack_biases (
395
+ graph,
396
+ bias,
397
+ weight,
398
+ /* transposed = */ false ,
399
+ /* storage_type = */ api::kTexture3D ,
400
+ /* memory_layout = */ api::kChannelsPacked );
395
401
396
402
vTensorPtr t_in = graph.get_tensor (arg_in);
397
403
vTensorPtr t_weight = graph.get_tensor (arg_weight);
0 commit comments