@@ -296,35 +296,39 @@ def quantized_matmul(
296
296
297
297
298
298
@linalg_structured_op
299
- def matmul_transpose_a (A = TensorDef (T1 , S .K , S .N ),
300
- B = TensorDef (T2 , S .K , S .M ),
301
- C = TensorDef (U , S .M , S .N , output = True ),
302
- cast = TypeFnAttrDef (default = TypeFn .cast_signed )):
303
- """Performs a matrix multiplication of two 2D inputs with lhs operand
304
- transposed.
299
+ def matmul_transpose_a (
300
+ A = TensorDef (T1 , S .K , S .N ),
301
+ B = TensorDef (T2 , S .K , S .M ),
302
+ C = TensorDef (U , S .M , S .N , output = True ),
303
+ cast = TypeFnAttrDef (default = TypeFn .cast_signed ),
304
+ ):
305
+ """Performs a matrix multiplication of two 2D inputs with lhs operand
306
+ transposed.
305
307
306
- Numeric casting is performed on the operands to the inner multiply, promoting
307
- them to the same data type as the accumulator/output.
308
- """
309
- domain (D .m , D .n , D .k )
310
- implements (ContractionOpInterface )
311
- C [D .m , D .n ] += cast (U , A [D .k , D .m ]) * cast (U , B [D .k , D .n ])
308
+ Numeric casting is performed on the operands to the inner multiply, promoting
309
+ them to the same data type as the accumulator/output.
310
+ """
311
+ domain (D .m , D .n , D .k )
312
+ implements (ContractionOpInterface )
313
+ C [D .m , D .n ] += cast (U , A [D .k , D .m ]) * cast (U , B [D .k , D .n ])
312
314
313
315
314
316
@linalg_structured_op
315
- def matmul_transpose_b (A = TensorDef (T1 , S .M , S .K ),
316
- B = TensorDef (T2 , S .N , S .K ),
317
- C = TensorDef (U , S .M , S .N , output = True ),
318
- cast = TypeFnAttrDef (default = TypeFn .cast_signed )):
319
- """Performs a matrix multiplication of two 2D inputs with rhs operand
320
- transposed.
317
+ def matmul_transpose_b (
318
+ A = TensorDef (T1 , S .M , S .K ),
319
+ B = TensorDef (T2 , S .N , S .K ),
320
+ C = TensorDef (U , S .M , S .N , output = True ),
321
+ cast = TypeFnAttrDef (default = TypeFn .cast_signed ),
322
+ ):
323
+ """Performs a matrix multiplication of two 2D inputs with rhs operand
324
+ transposed.
321
325
322
- Numeric casting is performed on the operands to the inner multiply, promoting
323
- them to the same data type as the accumulator/output.
324
- """
325
- domain (D .m , D .n , D .k )
326
- implements (ContractionOpInterface )
327
- C [D .m , D .n ] += cast (U , A [D .m , D .k ]) * cast (U , B [D .n , D .k ])
326
+ Numeric casting is performed on the operands to the inner multiply, promoting
327
+ them to the same data type as the accumulator/output.
328
+ """
329
+ domain (D .m , D .n , D .k )
330
+ implements (ContractionOpInterface )
331
+ C [D .m , D .n ] += cast (U , A [D .m , D .k ]) * cast (U , B [D .n , D .k ])
328
332
329
333
330
334
@linalg_structured_op
@@ -390,36 +394,41 @@ def batch_matmul(
390
394
391
395
392
396
@linalg_structured_op
393
- def batch_matmul_transpose_a (A = TensorDef (T1 , Batch , S .K , S .M ),
394
- B = TensorDef (T2 , Batch , S .K , S .N ),
395
- C = TensorDef (U , Batch , S .M , S .N , output = True )):
396
- """Performs a batched matrix multiplication of two 3D inputs where lhs operand
397
- has its non-batch dimensions transposed.
397
+ def batch_matmul_transpose_a (
398
+ A = TensorDef (T1 , Batch , S .K , S .M ),
399
+ B = TensorDef (T2 , Batch , S .K , S .N ),
400
+ C = TensorDef (U , Batch , S .M , S .N , output = True ),
401
+ ):
402
+ """Performs a batched matrix multiplication of two 3D inputs where lhs operand
403
+ has its non-batch dimensions transposed.
398
404
399
- Numeric casting is performed on the operands to the inner multiply, promoting
400
- them to the same data type as the accumulator/output.
401
- """
402
- domain (D .b , D .m , D .n , D .k )
403
- implements (ContractionOpInterface )
404
- C [D .b , D .m , D .n ] += TypeFn .cast_signed (U , A [D .b , D .k , D .m ]) \
405
- * TypeFn .cast_signed (U , B [D .b , D .k , D .n ])
405
+ Numeric casting is performed on the operands to the inner multiply, promoting
406
+ them to the same data type as the accumulator/output.
407
+ """
408
+ domain (D .b , D .m , D .n , D .k )
409
+ implements (ContractionOpInterface )
410
+ C [D .b , D .m , D .n ] += TypeFn .cast_signed (U , A [D .b , D .k , D .m ]) * TypeFn .cast_signed (
411
+ U , B [D .b , D .k , D .n ]
412
+ )
406
413
407
414
408
415
@linalg_structured_op
409
- def batch_matmul_transpose_b (A = TensorDef (T1 , Batch , S .M , S .K ),
410
- B = TensorDef (T2 , Batch , S .N , S .K ),
411
- C = TensorDef (U , Batch , S .M , S .N , output = True )):
412
- """Performs a batched matrix multiplication of two 3D inputs where rhs operand
413
- has its non-batch dimensions transposed.
416
+ def batch_matmul_transpose_b (
417
+ A = TensorDef (T1 , Batch , S .M , S .K ),
418
+ B = TensorDef (T2 , Batch , S .N , S .K ),
419
+ C = TensorDef (U , Batch , S .M , S .N , output = True ),
420
+ ):
421
+ """Performs a batched matrix multiplication of two 3D inputs where rhs operand
422
+ has its non-batch dimensions transposed.
414
423
415
- Numeric casting is performed on the operands to the inner multiply, promoting
416
- them to the same data type as the accumulator/output.
417
- """
418
- domain (D .b , D .m , D .n , D .k )
419
- implements (ContractionOpInterface )
420
- C [D .b , D .m ,
421
- D . n ] += TypeFn . cast_signed ( U , A [D .b , D .m , D .k ]) * TypeFn . cast_signed (
422
- U , B [ D . b , D . n , D . k ] )
424
+ Numeric casting is performed on the operands to the inner multiply, promoting
425
+ them to the same data type as the accumulator/output.
426
+ """
427
+ domain (D .b , D .m , D .n , D .k )
428
+ implements (ContractionOpInterface )
429
+ C [D .b , D .m , D . n ] += TypeFn . cast_signed ( U , A [ D . b , D . m , D . k ]) * TypeFn . cast_signed (
430
+ U , B [D .b , D .n , D .k ]
431
+ )
423
432
424
433
425
434
@linalg_structured_op
0 commit comments