@@ -341,6 +341,174 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
341
341
return %0 : vector <3 x2 xf32 >
342
342
}
343
343
344
// Masked contraction over the iteration space (m, k):
//   LHS = matrix indexed (m, k), RHS = vector indexed (k), ACC/result indexed (m).
// The mask has the shape of the iteration space (4x2).
// Expected output: a transpose of the mask, then a transpose of the matrix,
// followed by two masked vector.outerproduct ops (one per reduction step).
// CHECK-LABEL: @masked_matvec_mk_k_m
// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK:         vector.transpose %[[MAT]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract {
      indexing_maps = [affine_map<(m, k) -> (m, k)>,
                       affine_map<(m, k) -> (k)>,
                       affine_map<(m, k) -> (m)>],
      iterator_types = ["parallel", "reduction"],
      kind = #vector.kind<add>
    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}
364
+
365
// Masked contraction over the iteration space (m, k):
//   LHS = matrix indexed (k, m), RHS = vector indexed (k), ACC/result indexed (m).
// The mask has the shape of the iteration space (4x2).
// Expected output: a transpose of the mask, no transpose of the matrix, and
// two masked vector.outerproduct ops (one per reduction step).
// CHECK-LABEL: @masked_matvec_km_k_m
// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
  // The matrix must not be transposed anywhere in this function, so the
  // negative pattern is placed on both sides of the mask transpose.
  // CHECK-NOT:     vector.transpose %[[MAT]]
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-NOT:     vector.transpose %[[MAT]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract {
      indexing_maps = [affine_map<(m, k) -> (k, m)>,
                       affine_map<(m, k) -> (k)>,
                       affine_map<(m, k) -> (m)>],
      iterator_types = ["parallel", "reduction"],
      kind = #vector.kind<add>
    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}
385
+
386
// Masked contraction over the iteration space (m, k) with swapped operands:
//   LHS = vector indexed (k), RHS = matrix indexed (m, k), ACC/result indexed (m).
// The mask has the shape of the iteration space (4x2).
// Expected output: a transpose of the mask, then a transpose of the matrix,
// followed by two masked vector.outerproduct ops (one per reduction step).
// CHECK-LABEL: @masked_matvec_k_mk_m
// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK:         vector.transpose %[[MAT]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract {
      indexing_maps = [affine_map<(m, k) -> (k)>,
                       affine_map<(m, k) -> (m, k)>,
                       affine_map<(m, k) -> (m)>],
      iterator_types = ["parallel", "reduction"],
      kind = #vector.kind<add>
    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}
406
+
407
// Masked contraction over the iteration space (m, k) with swapped operands:
//   LHS = vector indexed (k), RHS = matrix indexed (k, m), ACC/result indexed (m).
// The mask has the shape of the iteration space (4x2).
// Expected output: a transpose of the mask, no transpose of the matrix, and
// two masked vector.outerproduct ops (one per reduction step).
// CHECK-LABEL: @masked_matvec_k_km_m
// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
  // The matrix must not be transposed anywhere in this function, so the
  // negative pattern is placed on both sides of the mask transpose.
  // CHECK-NOT:     vector.transpose %[[MAT]]
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-NOT:     vector.transpose %[[MAT]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract {
      indexing_maps = [affine_map<(m, k) -> (k)>,
                       affine_map<(m, k) -> (k, m)>,
                       affine_map<(m, k) -> (m)>],
      iterator_types = ["parallel", "reduction"],
      kind = #vector.kind<add>
    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}
427
+
428
// Masked contraction over the iteration space (k, m), i.e. reduction first:
//   LHS = matrix indexed (m, k), RHS = vector indexed (k), ACC/result indexed (m).
// The mask has the shape of the iteration space (2x4).
// Expected output: a transpose of the matrix, no transpose of the mask, and
// two masked vector.outerproduct ops (one per reduction step).
// CHECK-LABEL: @masked_tmatvec_mk_k_m
// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
  // The mask must not be transposed anywhere in this function. The negative
  // pattern is placed on both sides of the matrix transpose so that a mask
  // transpose emitted ahead of it is also rejected.
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK:         vector.transpose %[[MAT]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract {
      indexing_maps = [affine_map<(k, m) -> (m, k)>,
                       affine_map<(k, m) -> (k)>,
                       affine_map<(k, m) -> (m)>],
      iterator_types = ["reduction", "parallel"],
      kind = #vector.kind<add>
    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}
448
+
449
// Masked contraction over the iteration space (k, m), i.e. reduction first:
//   LHS = matrix indexed (k, m), RHS = vector indexed (k), ACC/result indexed (m).
// The mask has the shape of the iteration space (2x4).
// Expected output: neither the matrix nor the mask is transposed before the
// two masked vector.outerproduct ops (one per reduction step).
// CHECK-LABEL: @masked_tmatvec_km_k_m
// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
  // CHECK-NOT:     vector.transpose %[[MAT]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract {
      indexing_maps = [affine_map<(k, m) -> (k, m)>,
                       affine_map<(k, m) -> (k)>,
                       affine_map<(k, m) -> (m)>],
      iterator_types = ["reduction", "parallel"],
      kind = #vector.kind<add>
    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}
469
+
470
// Masked contraction over the iteration space (k, m) with swapped operands:
//   LHS = vector indexed (k), RHS = matrix indexed (m, k), ACC/result indexed (m).
// The mask has the shape of the iteration space (2x4).
// Expected output: a transpose of the matrix, no transpose of the mask, and
// two masked vector.outerproduct ops (one per reduction step).
// CHECK-LABEL: @masked_tmatvec_k_mk_m
// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
  // The mask must not be transposed anywhere in this function. The negative
  // pattern is placed on both sides of the matrix transpose so that a mask
  // transpose emitted ahead of it is also rejected.
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK:         vector.transpose %[[MAT]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract {
      indexing_maps = [affine_map<(k, m) -> (k)>,
                       affine_map<(k, m) -> (m, k)>,
                       affine_map<(k, m) -> (m)>],
      iterator_types = ["reduction", "parallel"],
      kind = #vector.kind<add>
    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}
490
+
491
// Masked contraction over the iteration space (k, m) with swapped operands:
//   LHS = vector indexed (k), RHS = matrix indexed (k, m), ACC/result indexed (m).
// The mask has the shape of the iteration space (2x4).
// Expected output: neither the matrix nor the mask is transposed before the
// two masked vector.outerproduct ops (one per reduction step).
// CHECK-LABEL: @masked_tmatvec_k_km_m
// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
  // CHECK-NOT:     vector.transpose %[[MAT]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract {
      indexing_maps = [affine_map<(k, m) -> (k)>,
                       affine_map<(k, m) -> (k, m)>,
                       affine_map<(k, m) -> (m)>],
      iterator_types = ["reduction", "parallel"],
      kind = #vector.kind<add>
    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}
511
+
344
512
345
513
transform.sequence failures (propagate ) {
346
514
^bb1 (%module_op: !transform.any_op ):
0 commit comments