Skip to content

Commit 3aad0d8

Browse files
replace hipblasLtComputeType_t with hipblasComputeType_t
1 parent 95131d6 commit 3aad0d8

33 files changed

+176
-232
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Full documentation for hipBLASLt is available at [rocm.docs.amd.com/projects/hip
2828
### Changes
2929

3030
* Replaced `hipblasDatatype_t` with `hipDataType`
31+
* Replaced `hipblasLtComputeType_t` with `hipblasComputeType_t`
3132

3233
### Removals
3334

clients/benchmarks/client.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -751,11 +751,11 @@ try
751751
if(arg.d_type == HIPBLASLT_DATATYPE_INVALID)
752752
throw std::invalid_argument("Invalid value for --d_type " + d_type);
753753

754-
bool is_f16 = arg.a_type == HIP_R_16F || arg.a_type == HIP_R_16BF;
755-
bool is_f32 = arg.a_type == HIP_R_32F;
756-
arg.compute_type = compute_type == "" ? (HIPBLASLT_COMPUTE_F32)
757-
: string_to_hipblaslt_computetype(compute_type);
758-
if(arg.compute_type == static_cast<hipblasLtComputeType_t>(0))
754+
bool is_f16 = arg.a_type == HIP_R_16F || arg.a_type == HIP_R_16BF;
755+
bool is_f32 = arg.a_type == HIP_R_32F;
756+
arg.compute_type
757+
= compute_type == "" ? (HIPBLAS_COMPUTE_32F) : string_to_hipblas_computetype(compute_type);
758+
if(arg.compute_type == static_cast<hipblasComputeType_t>(0))
759759
throw std::invalid_argument("Invalid value for --compute_type " + compute_type);
760760

761761
if(string_to_hip_datatype(bias_type) == HIPBLASLT_DATATYPE_INVALID && bias_type != ""

clients/benchmarks/client_groupedgemm_fixed_mk.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -782,7 +782,7 @@ int test_hipblaslt(hipDataType in_datatype,
782782
in_datatype,
783783
out_datatype,
784784
out_datatype,
785-
HIPBLASLT_COMPUTE_F32,
785+
HIPBLAS_COMPUTE_32F,
786786
heuristicResult));
787787

788788
std::vector<int> validIdx;
@@ -795,7 +795,7 @@ int test_hipblaslt(hipDataType in_datatype,
795795
in_datatype,
796796
out_datatype,
797797
out_datatype,
798-
HIPBLASLT_COMPUTE_F32);
798+
HIPBLAS_COMPUTE_32F);
799799

800800
std::cout << "index, transAB, M, N, K, lda, ldb, ldc, stride_a, stride_b, "
801801
"stride_c, batch_count, alpha, beta, bias, activationType"
@@ -844,7 +844,7 @@ int test_hipblaslt(hipDataType in_datatype,
844844
in_datatype,
845845
out_datatype,
846846
out_datatype,
847-
HIPBLASLT_COMPUTE_F32};
847+
HIPBLAS_COMPUTE_32F};
848848

849849
// step 1: set problem to {Ms, {sum of N, 1, 1, 1, ...}, Ks}
850850
CHECK_HIPBLASLT_ERROR(groupedGemm.setProblem(m,

clients/common/hipblaslt_arguments.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ void Arguments::init()
104104
b_type = HIP_R_16F;
105105
c_type = HIP_R_16F;
106106
d_type = HIP_R_16F;
107-
compute_type = HIPBLASLT_COMPUTE_F32;
107+
compute_type = HIPBLAS_COMPUTE_32F;
108108
scale_type = HIP_R_32F;
109109

110110
initialization = hipblaslt_initialization::hpl;

clients/gtest/matmul_gtest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ namespace
112112
{
113113
name << hip_datatype_to_string(arg.a_type) << hip_datatype_to_string(arg.b_type)
114114
<< hip_datatype_to_string(arg.c_type) << hip_datatype_to_string(arg.d_type)
115-
<< hipblaslt_computetype_to_string(arg.compute_type);
115+
<< hipblas_computetype_to_string(arg.compute_type);
116116

117117
if(arg.activation_type != hipblaslt_activation_type::none)
118118
{

clients/include/hipblaslt_arguments.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,12 @@ struct Arguments
9292
int32_t solution_index;
9393
int32_t requested_solution_num;
9494

95-
hipDataType a_type;
96-
hipDataType b_type;
97-
hipDataType c_type;
98-
hipDataType d_type;
99-
hipblasLtComputeType_t compute_type;
100-
hipDataType scale_type;
95+
hipDataType a_type;
96+
hipDataType b_type;
97+
hipDataType c_type;
98+
hipDataType d_type;
99+
hipblasComputeType_t compute_type;
100+
hipDataType scale_type;
101101

102102
hipblaslt_initialization initialization;
103103

clients/include/hipblaslt_common.yaml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ Datatypes:
1313
bf16_r: 14
1414
f8_r: 1000
1515
bf8_r: 1001
16-
- hipblasLtComputeType_t:
16+
- hipblasComputeType_t:
1717
bases: [ c_int ]
1818
attr:
19-
c_f32_r: 300
20-
c_xf32_r: 301
21-
c_f64_r: 302
22-
c_i32_r: 303
23-
c_f32_fast_f16_r: 304
19+
c_f32_r: 2
20+
c_f32_fast_f16_r: 4
21+
c_xf32_r: 6
22+
c_f64_r: 7
23+
c_i32_r: 9
2424
- { half: f16_r }
2525
- hipblaslt_initialization:
2626
bases: [ c_int ]
@@ -172,7 +172,7 @@ Arguments:
172172
- b_type: hipDataType
173173
- c_type: hipDataType
174174
- d_type: hipDataType
175-
- compute_type: hipblasLtComputeType_t
175+
- compute_type: hipblasComputeType_t
176176
- scale_type: hipDataType
177177
- initialization: hipblaslt_initialization
178178
- gpu_arch: c_char*4

clients/include/hipblaslt_test.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ struct hipblaslt_test_invalid
433433
<< " b: " << hip_datatype_to_string(arg.b_type)
434434
<< " c: " << hip_datatype_to_string(arg.c_type)
435435
<< " d: " << hip_datatype_to_string(arg.d_type)
436-
<< " compute:" << hipblaslt_computetype_to_string(arg.compute_type)
436+
<< " compute:" << hipblas_computetype_to_string(arg.compute_type)
437437
<< std::endl;
438438
hipblaslt_abort();
439439
#endif

clients/include/testing_matmul.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1782,7 +1782,7 @@ void testing_matmul(const Arguments& arg)
17821782
// For the xf32 xdl math op, cast type of A/B from float to xfloat32 .
17831783
if constexpr(std::is_same<TiA, float>{} && std::is_same<TiB, float>{}
17841784
&& std::is_same<To, float>{} && std::is_same<Tc, float>{})
1785-
if(arg.compute_type == HIPBLASLT_COMPUTE_F32_FAST_XF32)
1785+
if(arg.compute_type == HIPBLAS_COMPUTE_32F_FAST_TF32)
17861786
{
17871787
for(int i = 0; i < gemm_count; i++)
17881788
{

clients/include/type_dispatch.hpp

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -97,85 +97,85 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
9797

9898
if(arg.d_type == To)
9999
{
100-
if(TiA == To && TiB == To && To == HIP_R_16F && Tc == HIPBLASLT_COMPUTE_F32)
100+
if(TiA == To && TiB == To && To == HIP_R_16F && Tc == HIPBLAS_COMPUTE_32F)
101101
{
102102
return TEST<hipblasLtHalf, hipblasLtHalf, hipblasLtHalf, float>{}(arg);
103103
}
104-
else if(TiA == To && TiB == To && To == HIP_R_16BF && Tc == HIPBLASLT_COMPUTE_F32)
104+
else if(TiA == To && TiB == To && To == HIP_R_16BF && Tc == HIPBLAS_COMPUTE_32F)
105105
{
106106
return TEST<hip_bfloat16, hip_bfloat16, hip_bfloat16, float>{}(arg);
107107
}
108108
else if(TiA == To && TiB == To && To == HIP_R_32F
109-
&& (Tc == HIPBLASLT_COMPUTE_F32 || Tc == HIPBLASLT_COMPUTE_F32_FAST_XF32))
109+
&& (Tc == HIPBLAS_COMPUTE_32F || Tc == HIPBLAS_COMPUTE_32F_FAST_TF32))
110110
{
111111
return TEST<float, float, float, float>{}(arg);
112112
}
113-
else if(TiA == To && TiB == To && To == HIP_R_64F && (Tc == HIPBLASLT_COMPUTE_F64))
113+
else if(TiA == To && TiB == To && To == HIP_R_64F && (Tc == HIPBLAS_COMPUTE_64F))
114114
{
115115
return TEST<double, double, double, double>{}(arg);
116116
}
117117
else if(TiA == HIP_R_16F && TiB == HIP_R_16F && To == HIP_R_32F
118-
&& Tc == HIPBLASLT_COMPUTE_F32)
118+
&& Tc == HIPBLAS_COMPUTE_32F)
119119
{
120120
return TEST<hipblasLtHalf, hipblasLtHalf, float, float>{}(arg);
121121
}
122122
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_32F
123-
&& Tc == HIPBLASLT_COMPUTE_F32)
123+
&& Tc == HIPBLAS_COMPUTE_32F)
124124
{
125125
return TEST<hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, float, float>{}(arg);
126126
}
127127
else if(TiA == HIP_R_8F_E5M2_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_32F
128-
&& Tc == HIPBLASLT_COMPUTE_F32)
128+
&& Tc == HIPBLAS_COMPUTE_32F)
129129
{
130130
return TEST<hipblaslt_bf8_fnuz, hipblaslt_f8_fnuz, float, float>{}(arg);
131131
}
132132
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E5M2_FNUZ && To == HIP_R_32F
133-
&& Tc == HIPBLASLT_COMPUTE_F32)
133+
&& Tc == HIPBLAS_COMPUTE_32F)
134134
{
135135
return TEST<hipblaslt_f8_fnuz, hipblaslt_bf8_fnuz, float, float>{}(arg);
136136
}
137137
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_16F
138-
&& Tc == HIPBLASLT_COMPUTE_F32)
138+
&& Tc == HIPBLAS_COMPUTE_32F)
139139
{
140140
return TEST<hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, hipblasLtHalf, float>{}(arg);
141141
}
142142
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_16BF
143-
&& Tc == HIPBLASLT_COMPUTE_F32)
143+
&& Tc == HIPBLAS_COMPUTE_32F)
144144
{
145145
return TEST<hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, hipblasLtBfloat16, float>{}(arg);
146146
}
147147
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_8F_E4M3_FNUZ
148-
&& Tc == HIPBLASLT_COMPUTE_F32)
148+
&& Tc == HIPBLAS_COMPUTE_32F)
149149
{
150150
return TEST<hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, float>{}(arg);
151151
}
152152
else if(TiA == HIP_R_8F_E5M2_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_16F
153-
&& Tc == HIPBLASLT_COMPUTE_F32)
153+
&& Tc == HIPBLAS_COMPUTE_32F)
154154
{
155155
return TEST<hipblaslt_bf8_fnuz, hipblaslt_f8_fnuz, hipblasLtHalf, float>{}(arg);
156156
}
157157
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E5M2_FNUZ && To == HIP_R_16F
158-
&& Tc == HIPBLASLT_COMPUTE_F32)
158+
&& Tc == HIPBLAS_COMPUTE_32F)
159159
{
160160
return TEST<hipblaslt_f8_fnuz, hipblaslt_bf8_fnuz, hipblasLtHalf, float>{}(arg);
161161
}
162162
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E5M2_FNUZ && To == HIP_R_16BF
163-
&& Tc == HIPBLASLT_COMPUTE_F32)
163+
&& Tc == HIPBLAS_COMPUTE_32F)
164164
{
165165
return TEST<hipblaslt_f8_fnuz, hipblaslt_bf8_fnuz, hipblasLtBfloat16, float>{}(arg);
166166
}
167167
/*
168-
else if(Ti == HIP_R_8I && To == HIP_R_8I && Tc == HIPBLASLT_COMPUTE_I32)
168+
else if(Ti == HIP_R_8I && To == HIP_R_8I && Tc == HIPBLAS_COMPUTE_32I)
169169
{
170170
return TEST<hipblasLtInt8, hipblasLtInt8, int32_t>{}(arg);
171171
}
172172
*/
173-
else if(TiA == HIP_R_8I && To == HIP_R_32I && Tc == HIPBLASLT_COMPUTE_I32)
173+
else if(TiA == HIP_R_8I && To == HIP_R_32I && Tc == HIPBLAS_COMPUTE_32I)
174174
{
175175
return TEST<hipblasLtInt8, hipblasLtInt8, int32_t, int32_t>{}(arg);
176176
}
177177
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_8F_E4M3_FNUZ
178-
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
178+
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
179179
{
180180
return TEST<hipblaslt_f8_fnuz,
181181
hipblasLtHalf,
@@ -184,7 +184,7 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
184184
hipblasLtHalf>{}(arg);
185185
}
186186
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_8F_E4M3_FNUZ
187-
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
187+
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
188188
{
189189
return TEST<hipblasLtHalf,
190190
hipblaslt_f8_fnuz,
@@ -193,29 +193,29 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
193193
hipblasLtHalf>{}(arg);
194194
}
195195
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_16F
196-
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
196+
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
197197
{
198198
return TEST<hipblaslt_f8_fnuz, hipblasLtHalf, hipblasLtHalf, float, hipblasLtHalf>{}(
199199
arg);
200200
}
201201
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_16F
202-
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
202+
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
203203
{
204204
return TEST<hipblasLtHalf, hipblaslt_f8_fnuz, hipblasLtHalf, float, hipblasLtHalf>{}(
205205
arg);
206206
}
207207
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_32F
208-
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
208+
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
209209
{
210210
return TEST<hipblaslt_f8_fnuz, hipblasLtHalf, float, float, hipblasLtHalf>{}(arg);
211211
}
212212
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_32F
213-
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
213+
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
214214
{
215215
return TEST<hipblasLtHalf, hipblaslt_f8_fnuz, float, float, hipblasLtHalf>{}(arg);
216216
}
217217
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_8F_E4M3_FNUZ
218-
&& Tc == HIPBLASLT_COMPUTE_F32)
218+
&& Tc == HIPBLAS_COMPUTE_32F)
219219
{
220220
return TEST<hipblaslt_f8_fnuz,
221221
hipblasLtHalf,
@@ -224,7 +224,7 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
224224
hipblaslt_f8_fnuz>{}(arg);
225225
}
226226
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_8F_E4M3_FNUZ
227-
&& Tc == HIPBLASLT_COMPUTE_F32)
227+
&& Tc == HIPBLAS_COMPUTE_32F)
228228
{
229229
return TEST<hipblasLtHalf,
230230
hipblaslt_f8_fnuz,
@@ -233,7 +233,7 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
233233
hipblaslt_f8_fnuz>{}(arg);
234234
}
235235
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_16F
236-
&& Tc == HIPBLASLT_COMPUTE_F32)
236+
&& Tc == HIPBLAS_COMPUTE_32F)
237237
{
238238
return TEST<hipblaslt_f8_fnuz,
239239
hipblasLtHalf,
@@ -242,7 +242,7 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
242242
hipblaslt_f8_fnuz>{}(arg);
243243
}
244244
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_16F
245-
&& Tc == HIPBLASLT_COMPUTE_F32)
245+
&& Tc == HIPBLAS_COMPUTE_32F)
246246
{
247247
return TEST<hipblasLtHalf,
248248
hipblaslt_f8_fnuz,
@@ -251,12 +251,12 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
251251
hipblaslt_f8_fnuz>{}(arg);
252252
}
253253
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_32F
254-
&& Tc == HIPBLASLT_COMPUTE_F32)
254+
&& Tc == HIPBLAS_COMPUTE_32F)
255255
{
256256
return TEST<hipblaslt_f8_fnuz, hipblasLtHalf, float, float, hipblaslt_f8_fnuz>{}(arg);
257257
}
258258
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_32F
259-
&& Tc == HIPBLASLT_COMPUTE_F32)
259+
&& Tc == HIPBLAS_COMPUTE_32F)
260260
{
261261
return TEST<hipblasLtHalf, hipblaslt_f8_fnuz, float, float, hipblaslt_f8_fnuz>{}(arg);
262262
}

clients/include/utility.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,10 @@ class hipblaslt_local_matmul_descr
191191
hipblasStatus_t m_status = HIPBLAS_STATUS_NOT_INITIALIZED;
192192

193193
public:
194-
hipblaslt_local_matmul_descr(hipblasOperation_t opA,
195-
hipblasOperation_t opB,
196-
hipblasLtComputeType_t compute_type,
197-
hipDataType scale_type)
194+
hipblaslt_local_matmul_descr(hipblasOperation_t opA,
195+
hipblasOperation_t opB,
196+
hipblasComputeType_t compute_type,
197+
hipDataType scale_type)
198198
{
199199
this->m_status = hipblasLtMatmulDescCreate(&this->m_descr, compute_type, scale_type);
200200

clients/samples/gemm/sample_hipblaslt_gemm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ void simpleGemm(hipblasLtHandle_t handle,
127127
}
128128

129129
hipblasLtMatmulDesc_t matmul;
130-
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIP_R_32F));
130+
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
131131
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
132132
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t)));
133133
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(

clients/samples/gemm/sample_hipblaslt_gemm_ext.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,8 @@ void simpleGemmExt(hipblasLtHandle_t handle,
9999
{
100100
hipblaslt_ext::GemmPreference gemmPref;
101101
gemmPref.setMaxWorkspaceBytes(max_workspace_size);
102-
hipblaslt_ext::Gemm gemm(handle,
103-
trans_a,
104-
trans_b,
105-
HIP_R_16F,
106-
HIP_R_16F,
107-
HIP_R_16F,
108-
HIP_R_16F,
109-
HIPBLASLT_COMPUTE_F32);
102+
hipblaslt_ext::Gemm gemm(
103+
handle, trans_a, trans_b, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIPBLAS_COMPUTE_32F);
110104

111105
hipblaslt_ext::GemmEpilogue
112106
epilogue; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)

clients/samples/gemm_alphavec/sample_hipblaslt_gemm_alphavec_ext.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,8 @@ void simpleGemmAlphaVecExt(hipblasLtHandle_t handle,
102102
{
103103
hipblaslt_ext::GemmPreference gemmPref;
104104
gemmPref.setMaxWorkspaceBytes(max_workspace_size);
105-
hipblaslt_ext::Gemm gemm(handle,
106-
trans_a,
107-
trans_b,
108-
HIP_R_16F,
109-
HIP_R_16F,
110-
HIP_R_16F,
111-
HIP_R_16F,
112-
HIPBLASLT_COMPUTE_F32);
105+
hipblaslt_ext::Gemm gemm(
106+
handle, trans_a, trans_b, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIPBLAS_COMPUTE_32F);
113107

114108
hipblaslt_ext::GemmEpilogue
115109
epilogue; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)

0 commit comments

Comments
 (0)