Skip to content

Commit 276b9cf

Browse files
authored
[test-e2e][Matrix] Fix joint matrix load address (#10259)
The loading address didn't calculate the column offset right w.r.t the sub-group id. Signed-off-by: Yilong Guo <[email protected]>
1 parent 00533e0 commit 276b9cf

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

sycl/test-e2e/Matrix/get_coord_bf16_matA_impl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ void matrix_sum_rows(queue q, big_matrix<T, M, K> &A, nd_range<2> &r) {
117117
sub_a;
118118

119119
joint_matrix_load(
120-
sg, sub_a, accA.template get_multi_ptr<access::decorated::no>() + (global_idx * TM * K) + TK,
120+
sg, sub_a, accA.template get_multi_ptr<access::decorated::no>() + (sg_startx * TM * K) + sg_starty / SG_SZ * TK,
121121
K);
122122

123123
// calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_a
@@ -175,7 +175,7 @@ int main() {
175175

176176
for (int i = 0; i < MATRIX_M; i++) {
177177
for (int j = 0; j < MATRIX_K; j++) {
178-
A[i][j] = i;
178+
A[i][j] = i + j;
179179
}
180180
}
181181

sycl/test-e2e/Matrix/get_coord_bf16_matB_impl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ void matrix_sum_cols(queue q, big_matrix<T, M, N> &B, nd_range<2> &r) {
142142
joint_matrix_load(
143143
sg, sub_b,
144144
accB.template get_multi_ptr<access::decorated::no>() +
145-
(global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4,
145+
(sg_startx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4,
146146
N);
147147

148148
int32_t sum_local_cols[N] = {0}; // 4 local cols, N total
@@ -207,7 +207,7 @@ int main() {
207207

208208
for (int i = 0; i < MATRIX_K; i++) {
209209
for (int j = 0; j < MATRIX_N; j++) {
210-
B[i][j] = i;
210+
B[i][j] = i + j;
211211
}
212212
}
213213

0 commit comments

Comments
 (0)